Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 15:39:47 +00:00)

Merge branch 'main' into in-out-tree-provider-guide

Commit 99f6ee3d6c: 159 changed files with 13342 additions and 8254 deletions
.github/actions/run-and-record-tests/action.yml (20 changes, vendored)

@@ -2,9 +2,13 @@ name: 'Run and Record Tests'
 description: 'Run integration tests and handle recording/artifact upload'

 inputs:
-  test-types:
-    description: 'JSON array of test types to run'
+  test-subdirs:
+    description: 'Comma-separated list of test subdirectories to run'
     required: true
+  test-pattern:
+    description: 'Regex pattern to pass to pytest -k'
+    required: false
+    default: ''
   stack-config:
     description: 'Stack configuration to use'
     required: true

@@ -35,9 +39,11 @@ runs:
         ./scripts/integration-tests.sh \
           --stack-config '${{ inputs.stack-config }}' \
           --provider '${{ inputs.provider }}' \
-          --test-types '${{ inputs.test-types }}' \
+          --test-subdirs '${{ inputs.test-subdirs }}' \
+          --test-pattern '${{ inputs.test-pattern }}' \
           --inference-mode '${{ inputs.inference-mode }}' \
-          ${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }}
+          ${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} \
+          | tee pytest-${{ inputs.inference-mode }}.log


     - name: Commit and push recordings

@@ -57,10 +63,10 @@ runs:
           git commit -m "Recordings update from CI"
         fi

-        git fetch origin ${{ github.event.pull_request.head.ref }}
-        git rebase origin/${{ github.event.pull_request.head.ref }}
+        git fetch origin ${{ github.ref_name }}
+        git rebase origin/${{ github.ref_name }}
         echo "Rebased successfully"
-        git push origin HEAD:${{ github.event.pull_request.head.ref }}
+        git push origin HEAD:${{ github.ref_name }}
         echo "Pushed successfully"
       else
         echo "No recording changes"
.github/workflows/integration-tests.yml (35 changes, vendored)

@@ -5,7 +5,7 @@ run-name: Run the integration test suite from tests/integration in replay mode
 on:
   push:
     branches: [ main ]
-  pull_request_target:
+  pull_request:
     branches: [ main ]
     types: [opened, synchronize, reopened]
     paths:

@@ -31,35 +31,23 @@ on:
       description: 'Test against a specific provider'
       type: string
       default: 'ollama'
+    test-subdirs:
+      description: 'Comma-separated list of test subdirectories to run'
+      type: string
+      default: ''
+    test-pattern:
+      description: 'Regex pattern to pass to pytest -k'
+      type: string
+      default: ''

 concurrency:
   # Skip concurrency for pushes to main - each commit should be tested independently
-  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.event.pull_request.number }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:
-  discover-tests:
-    runs-on: ubuntu-latest
-    outputs:
-      test-types: ${{ steps.generate-test-types.outputs.test-types }}
-
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-
-      - name: Generate test types
-        id: generate-test-types
-        run: |
-          # Get test directories dynamically, excluding non-test directories
-          # NOTE: we are excluding post_training since the tests take too long
-          TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d |
-            sed 's|tests/integration/||' |
-            grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
-            sort | jq -R -s -c 'split("\n")[:-1]')
-          echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
-
   run-replay-mode-tests:
-    needs: discover-tests
     runs-on: ubuntu-latest
     name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }}

@@ -90,7 +78,8 @@ jobs:
       - name: Run tests
         uses: ./.github/actions/run-and-record-tests
         with:
-          test-types: ${{ needs.discover-tests.outputs.test-types }}
+          test-subdirs: ${{ inputs.test-subdirs }}
+          test-pattern: ${{ inputs.test-pattern }}
           stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
           provider: ${{ matrix.provider }}
           inference-mode: 'replay'
@@ -14,9 +14,11 @@ on:
       - 'pyproject.toml'
       - 'requirements.txt'
       - '.github/workflows/integration-vector-io-tests.yml' # This workflow
+  schedule:
+    - cron: '0 0 * * *'  # (test on python 3.13) Daily at 12 AM UTC

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true

 jobs:

@@ -25,7 +27,7 @@ jobs:
     strategy:
       matrix:
         vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"]
-        python-version: ["3.12", "3.13"]
+        python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
       fail-fast: false  # we want to run all tests regardless of failure

     steps:
.github/workflows/record-integration-tests.yml (101 changes, vendored)

@@ -1,93 +1,53 @@
+# This workflow should be run manually when needing to re-record tests. This happens when you have
+# - added a new test
+# - or changed an existing test such that a new inference call is made
+# You should make a PR and then run this workflow on that PR branch. The workflow will re-record the
+# tests and commit the recordings to the PR branch.
 name: Integration Tests (Record)

 run-name: Run the integration test suite from tests/integration

 on:
-  pull_request:
-    branches: [ main ]
-    types: [opened, synchronize, labeled]
-    paths:
-      - 'llama_stack/**'
-      - 'tests/**'
-      - 'uv.lock'
-      - 'pyproject.toml'
-      - '.github/workflows/record-integration-tests.yml' # This workflow
-      - '.github/actions/setup-ollama/action.yml'
-      - '.github/actions/setup-test-environment/action.yml'
-      - '.github/actions/run-and-record-tests/action.yml'
   workflow_dispatch:
     inputs:
+      test-subdirs:
+        description: 'Comma-separated list of test subdirectories to run'
+        type: string
+        default: ''
       test-provider:
         description: 'Test against a specific provider'
         type: string
         default: 'ollama'
-
-concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
-  cancel-in-progress: true
+      run-vision-tests:
+        description: 'Whether to run vision tests'
+        type: boolean
+        default: false
+      test-pattern:
+        description: 'Regex pattern to pass to pytest -k'
+        type: string
+        default: ''

 jobs:
-  discover-tests:
-    if: contains(github.event.pull_request.labels.*.name, 're-record-tests') ||
-      contains(github.event.pull_request.labels.*.name, 're-record-vision-tests')
-    runs-on: ubuntu-latest
-    outputs:
-      test-types: ${{ steps.generate-test-types.outputs.test-types }}
-      matrix-modes: ${{ steps.generate-test-types.outputs.matrix-modes }}
-
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-
-      - name: Generate test types
-        id: generate-test-types
-        run: |
-          # Get test directories dynamically, excluding non-test directories
-          TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
-            grep -Ev "^(__pycache__|fixtures|test_cases|recordings|post_training)$" |
-            sort | jq -R -s -c 'split("\n")[:-1]')
-          echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
-
-          labels=$(gh pr view ${{ github.event.pull_request.number }} --json labels --jq '.labels[].name')
-          echo "labels=$labels"
-
-          modes_array=()
-          if [[ $labels == *"re-record-vision-tests"* ]]; then
-            modes_array+=("vision")
-          fi
-          if [[ $labels == *"re-record-tests"* ]]; then
-            modes_array+=("non-vision")
-          fi
-
-          # Convert to JSON array
-          if [ ${#modes_array[@]} -eq 0 ]; then
-            matrix_modes="[]"
-          else
-            matrix_modes=$(printf '%s\n' "${modes_array[@]}" | jq -R -s -c 'split("\n")[:-1]')
-          fi
-          echo "matrix_modes=$matrix_modes"
-          echo "matrix-modes=$matrix_modes" >> $GITHUB_OUTPUT
-        env:
-          GH_TOKEN: ${{ github.token }}
-
   record-tests:
-    needs: discover-tests
     runs-on: ubuntu-latest

     permissions:
       contents: write

-    strategy:
-      fail-fast: false
-      matrix:
-        mode: ${{ fromJSON(needs.discover-tests.outputs.matrix-modes) }}
-
     steps:
+      - name: Echo workflow inputs
+        run: |
+          echo "::group::Workflow Inputs"
+          echo "test-subdirs: ${{ inputs.test-subdirs }}"
+          echo "test-provider: ${{ inputs.test-provider }}"
+          echo "run-vision-tests: ${{ inputs.run-vision-tests }}"
+          echo "test-pattern: ${{ inputs.test-pattern }}"
+          echo "branch: ${{ github.ref_name }}"
+          echo "::endgroup::"
+
       - name: Checkout repository
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
-          ref: ${{ github.event.pull_request.head.ref }}
           fetch-depth: 0

       - name: Setup test environment

@@ -96,14 +56,15 @@ jobs:
           python-version: "3.12"  # Use single Python version for recording
           client-version: "latest"
           provider: ${{ inputs.test-provider || 'ollama' }}
-          run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }}
+          run-vision-tests: ${{ inputs.run-vision-tests }}
           inference-mode: 'record'

       - name: Run and record tests
         uses: ./.github/actions/run-and-record-tests
         with:
-          test-types: ${{ needs.discover-tests.outputs.test-types }}
+          test-pattern: ${{ inputs.test-pattern }}
+          test-subdirs: ${{ inputs.test-subdirs }}
           stack-config: 'server:ci-tests'  # recording must be done with server since more tests are run
           provider: ${{ inputs.test-provider || 'ollama' }}
           inference-mode: 'record'
-          run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }}
+          run-vision-tests: ${{ inputs.run-vision-tests }}
@@ -2,6 +2,7 @@ exclude: 'build/'

 default_language_version:
   python: python3.12
+  node: "22"

 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks

@@ -145,6 +146,20 @@ repos:
       pass_filenames: false
       require_serial: true
       files: ^.github/workflows/.*$
+    - id: ui-prettier
+      name: Format UI code with Prettier
+      entry: bash -c 'cd llama_stack/ui && npm run format'
+      language: system
+      files: ^llama_stack/ui/.*\.(ts|tsx)$
+      pass_filenames: false
+      require_serial: true
+    - id: ui-eslint
+      name: Lint UI code with ESLint
+      entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet'
+      language: system
+      files: ^llama_stack/ui/.*\.(ts|tsx)$
+      pass_filenames: false
+      require_serial: true

 ci:
   autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
docs/_static/llama-stack-spec.html (6 changes, vendored)

@@ -14767,7 +14767,8 @@
       "OpenAIFilePurpose": {
         "type": "string",
         "enum": [
-          "assistants"
+          "assistants",
+          "batch"
         ],
         "title": "OpenAIFilePurpose",
         "description": "Valid purpose values for OpenAI Files API."

@@ -14844,7 +14845,8 @@
         "purpose": {
           "type": "string",
           "enum": [
-            "assistants"
+            "assistants",
+            "batch"
           ],
           "description": "The intended purpose of the file"
         }
docs/_static/llama-stack-spec.yaml (2 changes, vendored)

@@ -10951,6 +10951,7 @@ components:
       type: string
       enum:
         - assistants
+        - batch
       title: OpenAIFilePurpose
       description: >-
         Valid purpose values for OpenAI Files API.

@@ -11019,6 +11020,7 @@ components:
           type: string
           enum:
             - assistants
+            - batch
           description: The intended purpose of the file
         additionalProperties: false
         required:
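An editorial illustration of what the enum change above enables: an OpenAI-compatible client can now upload files with `purpose="batch"`. The `base_url` below is an assumption about a locally running Llama Stack server, not something stated in this commit.

```python
# Hedged sketch: upload a batch-input file via an OpenAI-compatible client.
# base_url and the file name are placeholders/assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

batch_input = client.files.create(
    file=open("requests.jsonl", "rb"),
    purpose="batch",  # newly allowed purpose value per the spec change above
)
print(batch_input.id)
```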
@@ -18,3 +18,4 @@ We are working on adding a few more APIs to complete the application lifecycle.
 - **Batch Inference**: run inference on a dataset of inputs
 - **Batch Agents**: run agents on a dataset of inputs
 - **Synthetic Data Generation**: generate synthetic data for model development
+- **Batches**: OpenAI-compatible batch management for inference
@@ -4,11 +4,11 @@

 ## Adding a New Provider

-See the [Adding a New API Provider Page](new_api_provider.md) which describes how to add new API providers to the Stack.
+See:
+- [Adding a New API Provider Page](new_api_provider.md) which describes how to add new API providers to the Stack.
+- [Vector Database Page](new_vector_database.md) which describes how to add a new vector databases with Llama Stack.
+- [External Provider Page](../providers/external/index.md) which describes how to add external providers to the Stack.

-See the [Vector Database Page](new_vector_database.md) which describes how to add a new vector databases with Llama Stack.
-
-See the [External Provider Page](../providers/external/index.md) which describes how to add external providers to the Stack.
 ```{toctree}
 :maxdepth: 1
 :hidden:

@@ -19,11 +19,21 @@ new_vector_database

 ## Testing

-See the [Test Page](testing.md) which describes how to test your changes.
+```{include} ../../../tests/README.md
+```
+
+## Advanced Topics
+
+For developers who need deeper understanding of the testing system internals:
+
 ```{toctree}
 :maxdepth: 1
-:hidden:
-:caption: Testing

-testing
+testing/record-replay
 ```
+
+### Benchmarking
+
+```{include} ../../../docs/source/distributions/k8s-benchmark/README.md
+```
@@ -1,8 +0,0 @@
-```{include} ../../../tests/README.md
-```
-
-```{include} ../../../tests/unit/README.md
-```
-
-```{include} ../../../tests/integration/README.md
-```
docs/source/contributing/testing/record-replay.md (new file, 234 lines)

@@ -0,0 +1,234 @@
# Record-Replay System

Understanding how Llama Stack captures and replays API interactions for testing.

## Overview

The record-replay system solves a fundamental challenge in AI testing: how do you test against expensive, non-deterministic APIs without breaking the bank or dealing with flaky tests?

The solution: intercept API calls, store real responses, and replay them later. This gives you real API behavior without the cost or variability.

## How It Works

### Request Hashing

Every API request gets converted to a deterministic hash for lookup:

```python
def normalize_request(method: str, url: str, headers: dict, body: dict) -> str:
    normalized = {
        "method": method.upper(),
        "endpoint": urlparse(url).path,  # Just the path, not full URL
        "body": body,  # Request parameters
    }
    return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()
```

**Key insight:** The hashing is intentionally precise. Different whitespace, float precision, or parameter order produces different hashes. This prevents subtle bugs from false cache hits.

```python
# These produce DIFFERENT hashes:
{"content": "Hello world"}
{"content": "Hello world\n"}
{"temperature": 0.7}
{"temperature": 0.7000001}
```
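To see that precision in practice, the snippet below recomputes the hash for two of the bodies listed above. It reuses the `normalize_request()` shown earlier; the URL is only a placeholder.

```python
# Self-contained demo: the trailing newline really does change the lookup key.
import hashlib
import json
from urllib.parse import urlparse

def normalize_request(method: str, url: str, headers: dict, body: dict) -> str:
    normalized = {
        "method": method.upper(),
        "endpoint": urlparse(url).path,
        "body": body,
    }
    return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()

url = "http://localhost:8321/v1/openai/v1/chat/completions"  # placeholder URL
h1 = normalize_request("POST", url, {}, {"content": "Hello world"})
h2 = normalize_request("POST", url, {}, {"content": "Hello world\n"})
assert h1 != h2  # different whitespace, different hash
```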
### Client Interception

The system patches OpenAI and Ollama client methods to intercept calls before they leave your application. This happens transparently - your test code doesn't change.
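A rough sketch of what "patching a client method" means here, assuming the openai-python v1 class layout; the wrapper below is illustrative and is not the actual Llama Stack recorder.

```python
# Illustrative only: the real record/replay dispatch lives in the Llama Stack test tooling.
from openai.resources.chat.completions import AsyncCompletions

_original_create = AsyncCompletions.create

async def _recording_create(self, *args, **kwargs):
    # RECORD mode: call through, persist the response, then return it.
    # REPLAY mode: skip the network call and return the stored response instead.
    return await _original_create(self, *args, **kwargs)

AsyncCompletions.create = _recording_create
```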
### Storage Architecture

Recordings use a two-tier storage system optimized for both speed and debuggability:

```
recordings/
├── index.sqlite           # Fast lookup by request hash
└── responses/
    ├── abc123def456.json  # Individual response files
    └── def789ghi012.json
```

**SQLite index** enables O(log n) hash lookups and metadata queries without loading response bodies.

**JSON files** store complete request/response pairs in human-readable format for debugging.
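A hedged sketch of how the two tiers work together on the replay path; the table and column names (`recordings`, `request_hash`, `response_file`) are assumptions rather than the real schema.

```python
import json
import sqlite3
from pathlib import Path

def load_recording(storage_dir: str, request_hash: str) -> dict | None:
    # Tier 1: fast hash lookup in the SQLite index (schema names are assumed).
    index = sqlite3.connect(Path(storage_dir) / "index.sqlite")
    try:
        row = index.execute(
            "SELECT response_file FROM recordings WHERE request_hash = ?",
            (request_hash,),
        ).fetchone()
    finally:
        index.close()
    if row is None:
        return None  # REPLAY mode treats this as a missing recording
    # Tier 2: the full request/response pair lives in a human-readable JSON file.
    return json.loads((Path(storage_dir) / "responses" / row[0]).read_text())
```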
## Recording Modes

### LIVE Mode

Direct API calls with no recording or replay:

```python
with inference_recording(mode=InferenceMode.LIVE):
    response = await client.chat.completions.create(...)
```

Use for initial development and debugging against real APIs.

### RECORD Mode

Captures API interactions while passing through real responses:

```python
with inference_recording(mode=InferenceMode.RECORD, storage_dir="./recordings"):
    response = await client.chat.completions.create(...)
    # Real API call made, response captured AND returned
```

The recording process:
1. Request intercepted and hashed
2. Real API call executed
3. Response captured and serialized
4. Recording stored to disk
5. Original response returned to caller

### REPLAY Mode

Returns stored responses instead of making API calls:

```python
with inference_recording(mode=InferenceMode.REPLAY, storage_dir="./recordings"):
    response = await client.chat.completions.create(...)
    # No API call made, cached response returned instantly
```

The replay process:
1. Request intercepted and hashed
2. Hash looked up in SQLite index
3. Response loaded from JSON file
4. Response deserialized and returned
5. Error if no recording found
## Streaming Support

Streaming APIs present a unique challenge: how do you capture an async generator?

### The Problem

```python
# How do you record this?
async for chunk in client.chat.completions.create(stream=True):
    process(chunk)
```

### The Solution

The system captures all chunks immediately before yielding any:

```python
async def handle_streaming_record(response):
    # Capture complete stream first
    chunks = []
    async for chunk in response:
        chunks.append(chunk)

    # Store complete recording
    storage.store_recording(
        request_hash, request_data, {"body": chunks, "is_streaming": True}
    )

    # Return generator that replays captured chunks
    async def replay_stream():
        for chunk in chunks:
            yield chunk

    return replay_stream()
```

This ensures:
- **Complete capture** - The entire stream is saved atomically
- **Interface preservation** - The returned object behaves like the original API
- **Deterministic replay** - Same chunks in the same order every time
## Serialization

API responses contain complex Pydantic objects that need careful serialization:

```python
def _serialize_response(response):
    if hasattr(response, "model_dump"):
        # Preserve type information for proper deserialization
        return {
            "__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}",
            "__data__": response.model_dump(mode="json"),
        }
    return response
```

This preserves type safety - when replayed, you get the same Pydantic objects with all their validation and methods.
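The replay side reverses that envelope. A minimal sketch, assuming the `__type__`/`__data__` format shown above and a top-level, importable Pydantic class; this is not the verbatim implementation.

```python
import importlib

def _deserialize_response(data):
    if isinstance(data, dict) and "__type__" in data and "__data__" in data:
        module_name, _, class_name = data["__type__"].rpartition(".")
        cls = getattr(importlib.import_module(module_name), class_name)
        # model_validate re-creates the typed Pydantic object, validation included.
        return cls.model_validate(data["__data__"])
    return data
```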
## Environment Integration

### Environment Variables

Control recording behavior globally:

```bash
export LLAMA_STACK_TEST_INFERENCE_MODE=replay
export LLAMA_STACK_TEST_RECORDING_DIR=/path/to/recordings
pytest tests/integration/
```

### Pytest Integration

The system integrates automatically based on environment variables, requiring no changes to test code.

## Debugging Recordings

### Inspecting Storage

```bash
# See what's recorded
sqlite3 recordings/index.sqlite "SELECT endpoint, model, timestamp FROM recordings LIMIT 10;"

# View specific response
cat recordings/responses/abc123def456.json | jq '.response.body'

# Find recordings by endpoint
sqlite3 recordings/index.sqlite "SELECT * FROM recordings WHERE endpoint='/v1/chat/completions';"
```

### Common Issues

**Hash mismatches:** Request parameters changed slightly between record and replay
```bash
# Compare request details
cat recordings/responses/abc123.json | jq '.request'
```

**Serialization errors:** Response types changed between versions
```bash
# Re-record with updated types
rm recordings/responses/failing_hash.json
LLAMA_STACK_TEST_INFERENCE_MODE=record pytest test_failing.py
```

**Missing recordings:** New test or changed parameters
```bash
# Record the missing interaction
LLAMA_STACK_TEST_INFERENCE_MODE=record pytest test_new.py
```

## Design Decisions

### Why Not Mocks?

Traditional mocking breaks down with AI APIs because:
- Response structures are complex and evolve frequently
- Streaming behavior is hard to mock correctly
- Edge cases in real APIs get missed
- Mocks become brittle maintenance burdens

### Why Precise Hashing?

Loose hashing (normalizing whitespace, rounding floats) seems convenient but hides bugs. If a test changes slightly, you want to know about it rather than accidentally getting the wrong cached response.

### Why JSON + SQLite?

- **JSON** - Human readable, diff-friendly, easy to inspect and modify
- **SQLite** - Fast indexed lookups without loading response bodies
- **Hybrid** - Best of both worlds for different use cases

This system provides reliable, fast testing against real AI APIs while maintaining the ability to debug issues when they arise.
docs/source/distributions/k8s-benchmark/README.md (new file, 156 lines)

@@ -0,0 +1,156 @@
# Llama Stack Benchmark Suite on Kubernetes

## Motivation

Performance benchmarking is critical for understanding the overhead and characteristics of the Llama Stack abstraction layer compared to direct inference engines like vLLM.

### Why This Benchmark Suite Exists

**Performance Validation**: The Llama Stack provides a unified API layer across multiple inference providers, but this abstraction introduces potential overhead. This benchmark suite quantifies the performance impact by comparing:
- Llama Stack inference (with vLLM backend)
- Direct vLLM inference calls
- Both under identical Kubernetes deployment conditions

**Production Readiness Assessment**: Real-world deployments require understanding performance characteristics under load. This suite simulates concurrent user scenarios with configurable parameters (duration, concurrency, request patterns) to validate production readiness.

**Regression Detection (TODO)**: As the Llama Stack evolves, this benchmark provides automated regression detection for performance changes. CI/CD pipelines can leverage these benchmarks to catch performance degradations before production deployments.

**Resource Planning**: By measuring throughput, latency percentiles, and resource utilization patterns, teams can make informed decisions about:
- Kubernetes resource allocation (CPU, memory, GPU)
- Auto-scaling configurations
- Cost optimization strategies

### Key Metrics Captured

The benchmark suite measures critical performance indicators:
- **Throughput**: Requests per second under sustained load
- **Latency Distribution**: P50, P95, P99 response times
- **Time to First Token (TTFT)**: Critical for streaming applications
- **Error Rates**: Request failures and timeout analysis

This data enables data-driven architectural decisions and performance optimization efforts.
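For reference, the percentile numbers above are plain order statistics over per-request timings. A standalone sketch mirroring the index arithmetic used by `benchmark.py` elsewhere in this commit (the sample latencies are made-up illustration data):

```python
import statistics

def summarize(latencies_s: list[float]) -> dict[str, float]:
    ordered = sorted(latencies_s)

    def pct(p: int) -> float:
        # Same nearest-rank style index math as the benchmark script.
        idx = max(0, min(len(ordered) - 1, int(len(ordered) * p / 100) - 1))
        return ordered[idx]

    return {
        "mean": statistics.mean(ordered),
        "p50": pct(50),
        "p95": pct(95),
        "p99": pct(99),
    }

print(summarize([0.21, 0.25, 0.27, 0.30, 0.45, 0.52, 0.61, 0.74, 0.98, 1.40]))
```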
## Setup

**1. Deploy base k8s infrastructure:**
```bash
cd ../k8s
./apply.sh
```

**2. Deploy benchmark components:**
```bash
cd ../k8s-benchmark
./apply.sh
```

**3. Verify deployment:**
```bash
kubectl get pods
# Should see: llama-stack-benchmark-server, vllm-server, etc.
```

## Quick Start

### Basic Benchmarks

**Benchmark Llama Stack (default):**
```bash
cd docs/source/distributions/k8s-benchmark/
./run-benchmark.sh
```

**Benchmark vLLM direct:**
```bash
./run-benchmark.sh --target vllm
```

### Custom Configuration

**Extended benchmark with high concurrency:**
```bash
./run-benchmark.sh --target vllm --duration 120 --concurrent 20
```

**Short test run:**
```bash
./run-benchmark.sh --target stack --duration 30 --concurrent 5
```

## Command Reference

### run-benchmark.sh Options

```bash
./run-benchmark.sh [options]

Options:
  -t, --target <stack|vllm>     Target to benchmark (default: stack)
  -d, --duration <seconds>      Duration in seconds (default: 60)
  -c, --concurrent <users>      Number of concurrent users (default: 10)
  -h, --help                    Show help message

Examples:
  ./run-benchmark.sh --target vllm         # Benchmark vLLM direct
  ./run-benchmark.sh --target stack        # Benchmark Llama Stack
  ./run-benchmark.sh -t vllm -d 120 -c 20  # vLLM with 120s, 20 users
```

## Local Testing

### Running Benchmark Locally

For local development without Kubernetes:

**1. Start OpenAI mock server:**
```bash
uv run python openai-mock-server.py --port 8080
```

**2. Run benchmark against mock server:**
```bash
uv run python benchmark.py \
  --base-url http://localhost:8080/v1 \
  --model mock-inference \
  --duration 30 \
  --concurrent 5
```

**3. Test against local vLLM server:**
```bash
# If you have vLLM running locally on port 8000
uv run python benchmark.py \
  --base-url http://localhost:8000/v1 \
  --model meta-llama/Llama-3.2-3B-Instruct \
  --duration 30 \
  --concurrent 5
```

**4. Profile the running server:**
```bash
./profile_running_server.sh
```

### OpenAI Mock Server

The `openai-mock-server.py` provides:
- **OpenAI-compatible API** for testing without real models
- **Configurable streaming delay** via `STREAM_DELAY_SECONDS` env var
- **Consistent responses** for reproducible benchmarks
- **Lightweight testing** without GPU requirements

**Mock server usage:**
```bash
uv run python openai-mock-server.py --port 8080
```

The mock server is also deployed in k8s as `openai-mock-service:8080` and can be used by changing the Llama Stack configuration to use the `mock-vllm-inference` provider.

## Files in this Directory

- `benchmark.py` - Core benchmark script with async streaming support
- `run-benchmark.sh` - Main script with target selection and configuration
- `openai-mock-server.py` - Mock OpenAI API server for local testing
- `README.md` - This documentation file
@@ -8,7 +8,6 @@

 # Deploys the benchmark-specific components on top of the base k8s deployment (../k8s/apply.sh).

-export MOCK_INFERENCE_PORT=8080
 export STREAM_DELAY_SECONDS=0.005

 export POSTGRES_USER=llamastack

@@ -20,14 +19,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B

 export MOCK_INFERENCE_MODEL=mock-inference

-# Use llama-stack-benchmark-service as the benchmark server
-export LOCUST_HOST=http://llama-stack-benchmark-service:8323
-export LOCUST_BASE_PATH=/v1/openai/v1
-
-# Use vllm-service as the benchmark server
-# export LOCUST_HOST=http://vllm-server:8000
-# export LOCUST_BASE_PATH=/v1
-
+export MOCK_INFERENCE_URL=openai-mock-service:8080

 export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL

@@ -35,13 +27,6 @@ set -euo pipefail
 set -x

 # Deploy benchmark-specific components
-# Deploy OpenAI mock server
-kubectl create configmap openai-mock --from-file=openai-mock-server.py \
-  --dry-run=client -o yaml | kubectl apply --validate=false -f -
-
-envsubst < openai-mock-deployment.yaml | kubectl apply --validate=false -f -
-
-# Create configmap with our custom stack config
 kubectl create configmap llama-stack-config --from-file=stack_run_config.yaml \
   --dry-run=client -o yaml > stack-configmap.yaml

@@ -49,9 +34,3 @@ kubectl apply --validate=false -f stack-configmap.yaml

 # Deploy our custom llama stack server (overriding the base one)
 envsubst < stack-k8s.yaml.template | kubectl apply --validate=false -f -
-
-# Deploy Locust load testing
-kubectl create configmap locust-script --from-file=locustfile.py \
-  --dry-run=client -o yaml | kubectl apply --validate=false -f -
-
-envsubst < locust-k8s.yaml | kubectl apply --validate=false -f -
268
docs/source/distributions/k8s-benchmark/benchmark.py
Normal file
268
docs/source/distributions/k8s-benchmark/benchmark.py
Normal file
|
|
@ -0,0 +1,268 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Simple benchmark script for Llama Stack with OpenAI API compatibility.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import statistics
|
||||||
|
import time
|
||||||
|
from typing import Tuple
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkStats:
|
||||||
|
def __init__(self):
|
||||||
|
self.response_times = []
|
||||||
|
self.ttft_times = []
|
||||||
|
self.chunks_received = []
|
||||||
|
self.errors = []
|
||||||
|
self.success_count = 0
|
||||||
|
self.total_requests = 0
|
||||||
|
self.concurrent_users = 0
|
||||||
|
self.start_time = None
|
||||||
|
self.end_time = None
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def add_result(self, response_time: float, chunks: int, ttft: float = None, error: str = None):
|
||||||
|
async with self._lock:
|
||||||
|
self.total_requests += 1
|
||||||
|
if error:
|
||||||
|
self.errors.append(error)
|
||||||
|
else:
|
||||||
|
self.success_count += 1
|
||||||
|
self.response_times.append(response_time)
|
||||||
|
self.chunks_received.append(chunks)
|
||||||
|
if ttft is not None:
|
||||||
|
self.ttft_times.append(ttft)
|
||||||
|
|
||||||
|
def print_summary(self):
|
||||||
|
if not self.response_times:
|
||||||
|
print("No successful requests to report")
|
||||||
|
if self.errors:
|
||||||
|
print(f"Total errors: {len(self.errors)}")
|
||||||
|
print("First 5 errors:")
|
||||||
|
for error in self.errors[:5]:
|
||||||
|
print(f" {error}")
|
||||||
|
return
|
||||||
|
|
||||||
|
total_time = self.end_time - self.start_time
|
||||||
|
success_rate = (self.success_count / self.total_requests) * 100
|
||||||
|
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"BENCHMARK RESULTS")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
print(f"Total time: {total_time:.2f}s")
|
||||||
|
print(f"Concurrent users: {self.concurrent_users}")
|
||||||
|
print(f"Total requests: {self.total_requests}")
|
||||||
|
print(f"Successful requests: {self.success_count}")
|
||||||
|
print(f"Failed requests: {len(self.errors)}")
|
||||||
|
print(f"Success rate: {success_rate:.1f}%")
|
||||||
|
print(f"Requests per second: {self.success_count / total_time:.2f}")
|
||||||
|
|
||||||
|
print(f"\nResponse Time Statistics:")
|
||||||
|
print(f" Mean: {statistics.mean(self.response_times):.3f}s")
|
||||||
|
print(f" Median: {statistics.median(self.response_times):.3f}s")
|
||||||
|
print(f" Min: {min(self.response_times):.3f}s")
|
||||||
|
print(f" Max: {max(self.response_times):.3f}s")
|
||||||
|
|
||||||
|
if len(self.response_times) > 1:
|
||||||
|
print(f" Std Dev: {statistics.stdev(self.response_times):.3f}s")
|
||||||
|
|
||||||
|
percentiles = [50, 90, 95, 99]
|
||||||
|
sorted_times = sorted(self.response_times)
|
||||||
|
print(f"\nPercentiles:")
|
||||||
|
for p in percentiles:
|
||||||
|
idx = int(len(sorted_times) * p / 100) - 1
|
||||||
|
idx = max(0, min(idx, len(sorted_times) - 1))
|
||||||
|
print(f" P{p}: {sorted_times[idx]:.3f}s")
|
||||||
|
|
||||||
|
if self.ttft_times:
|
||||||
|
print(f"\nTime to First Token (TTFT) Statistics:")
|
||||||
|
print(f" Mean: {statistics.mean(self.ttft_times):.3f}s")
|
||||||
|
print(f" Median: {statistics.median(self.ttft_times):.3f}s")
|
||||||
|
print(f" Min: {min(self.ttft_times):.3f}s")
|
||||||
|
print(f" Max: {max(self.ttft_times):.3f}s")
|
||||||
|
|
||||||
|
if len(self.ttft_times) > 1:
|
||||||
|
print(f" Std Dev: {statistics.stdev(self.ttft_times):.3f}s")
|
||||||
|
|
||||||
|
sorted_ttft = sorted(self.ttft_times)
|
||||||
|
print(f"\nTTFT Percentiles:")
|
||||||
|
for p in percentiles:
|
||||||
|
idx = int(len(sorted_ttft) * p / 100) - 1
|
||||||
|
idx = max(0, min(idx, len(sorted_ttft) - 1))
|
||||||
|
print(f" P{p}: {sorted_ttft[idx]:.3f}s")
|
||||||
|
|
||||||
|
if self.chunks_received:
|
||||||
|
print(f"\nStreaming Statistics:")
|
||||||
|
print(f" Mean chunks per response: {statistics.mean(self.chunks_received):.1f}")
|
||||||
|
print(f" Total chunks received: {sum(self.chunks_received)}")
|
||||||
|
|
||||||
|
if self.errors:
|
||||||
|
print(f"\nErrors (showing first 5):")
|
||||||
|
for error in self.errors[:5]:
|
||||||
|
print(f" {error}")
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaStackBenchmark:
|
||||||
|
def __init__(self, base_url: str, model_id: str):
|
||||||
|
self.base_url = base_url.rstrip('/')
|
||||||
|
self.model_id = model_id
|
||||||
|
self.headers = {"Content-Type": "application/json"}
|
||||||
|
self.test_messages = [
|
||||||
|
[{"role": "user", "content": "Hi"}],
|
||||||
|
[{"role": "user", "content": "What is the capital of France?"}],
|
||||||
|
[{"role": "user", "content": "Explain quantum physics in simple terms."}],
|
||||||
|
[{"role": "user", "content": "Write a short story about a robot learning to paint."}],
|
||||||
|
[
|
||||||
|
{"role": "user", "content": "What is machine learning?"},
|
||||||
|
{"role": "assistant", "content": "Machine learning is a subset of AI..."},
|
||||||
|
{"role": "user", "content": "Can you give me a practical example?"}
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def make_async_streaming_request(self) -> Tuple[float, int, float | None, str | None]:
|
||||||
|
"""Make a single async streaming chat completion request."""
|
||||||
|
messages = random.choice(self.test_messages)
|
||||||
|
payload = {
|
||||||
|
"model": self.model_id,
|
||||||
|
"messages": messages,
|
||||||
|
"stream": True,
|
||||||
|
"max_tokens": 100
|
||||||
|
}
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
chunks_received = 0
|
||||||
|
ttft = None
|
||||||
|
error = None
|
||||||
|
|
||||||
|
session = aiohttp.ClientSession()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.post(
|
||||||
|
f"{self.base_url}/chat/completions",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
timeout=aiohttp.ClientTimeout(total=30)
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
async for line in response.content:
|
||||||
|
if line:
|
||||||
|
line_str = line.decode('utf-8').strip()
|
||||||
|
if line_str.startswith('data: '):
|
||||||
|
chunks_received += 1
|
||||||
|
if ttft is None:
|
||||||
|
ttft = time.time() - start_time
|
||||||
|
if line_str == 'data: [DONE]':
|
||||||
|
break
|
||||||
|
|
||||||
|
if chunks_received == 0:
|
||||||
|
error = "No streaming chunks received"
|
||||||
|
else:
|
||||||
|
text = await response.text()
|
||||||
|
error = f"HTTP {response.status}: {text[:100]}"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error = f"Request error: {str(e)}"
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
response_time = time.time() - start_time
|
||||||
|
return response_time, chunks_received, ttft, error
|
||||||
|
|
||||||
|
|
||||||
|
async def run_benchmark(self, duration: int, concurrent_users: int) -> BenchmarkStats:
|
||||||
|
"""Run benchmark using async requests for specified duration."""
|
||||||
|
stats = BenchmarkStats()
|
||||||
|
stats.concurrent_users = concurrent_users
|
||||||
|
stats.start_time = time.time()
|
||||||
|
|
||||||
|
print(f"Starting benchmark: {duration}s duration, {concurrent_users} concurrent users")
|
||||||
|
print(f"Target URL: {self.base_url}/chat/completions")
|
||||||
|
print(f"Model: {self.model_id}")
|
||||||
|
|
||||||
|
connector = aiohttp.TCPConnector(limit=concurrent_users)
|
||||||
|
async with aiohttp.ClientSession(connector=connector) as session:
|
||||||
|
|
||||||
|
async def worker(worker_id: int):
|
||||||
|
"""Worker that sends requests sequentially until canceled."""
|
||||||
|
request_count = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
response_time, chunks, ttft, error = await self.make_async_streaming_request()
|
||||||
|
await stats.add_result(response_time, chunks, ttft, error)
|
||||||
|
request_count += 1
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
await stats.add_result(0, 0, None, f"Worker {worker_id} error: {str(e)}")
|
||||||
|
|
||||||
|
# Progress reporting task
|
||||||
|
async def progress_reporter():
|
||||||
|
last_report_time = time.time()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(1) # Report every second
|
||||||
|
if time.time() >= last_report_time + 10: # Report every 10 seconds
|
||||||
|
elapsed = time.time() - stats.start_time
|
||||||
|
print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s")
|
||||||
|
last_report_time = time.time()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Spawn concurrent workers
|
||||||
|
tasks = [asyncio.create_task(worker(i)) for i in range(concurrent_users)]
|
||||||
|
progress_task = asyncio.create_task(progress_reporter())
|
||||||
|
tasks.append(progress_task)
|
||||||
|
|
||||||
|
# Wait for duration then cancel all tasks
|
||||||
|
await asyncio.sleep(duration)
|
||||||
|
|
||||||
|
for task in tasks:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
# Wait for all tasks to complete
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
stats.end_time = time.time()
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Llama Stack Benchmark Tool")
|
||||||
|
parser.add_argument("--base-url", default=os.getenv("BENCHMARK_BASE_URL", "http://localhost:8000/v1/openai/v1"),
|
||||||
|
help="Base URL for the API (default: http://localhost:8000/v1/openai/v1)")
|
||||||
|
parser.add_argument("--model", default=os.getenv("INFERENCE_MODEL", "test-model"),
|
||||||
|
help="Model ID to use for requests")
|
||||||
|
parser.add_argument("--duration", type=int, default=60,
|
||||||
|
help="Duration in seconds to run benchmark (default: 60)")
|
||||||
|
parser.add_argument("--concurrent", type=int, default=10,
|
||||||
|
help="Number of concurrent users (default: 10)")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
benchmark = LlamaStackBenchmark(args.base_url, args.model)
|
||||||
|
|
||||||
|
try:
|
||||||
|
stats = asyncio.run(benchmark.run_benchmark(args.duration, args.concurrent))
|
||||||
|
stats.print_summary()
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nBenchmark interrupted by user")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Benchmark failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -1,131 +0,0 @@
|
||||||
apiVersion: apps/v1
|
|
||||||
kind: Deployment
|
|
||||||
metadata:
|
|
||||||
name: locust-master
|
|
||||||
labels:
|
|
||||||
app: locust
|
|
||||||
role: master
|
|
||||||
spec:
|
|
||||||
replicas: 1
|
|
||||||
selector:
|
|
||||||
matchLabels:
|
|
||||||
app: locust
|
|
||||||
role: master
|
|
||||||
template:
|
|
||||||
metadata:
|
|
||||||
labels:
|
|
||||||
app: locust
|
|
||||||
role: master
|
|
||||||
spec:
|
|
||||||
containers:
|
|
||||||
- name: locust-master
|
|
||||||
image: locustio/locust:2.31.8
|
|
||||||
ports:
|
|
||||||
- containerPort: 8089 # Web UI
|
|
||||||
- containerPort: 5557 # Master communication
|
|
||||||
env:
|
|
||||||
- name: LOCUST_HOST
|
|
||||||
value: "${LOCUST_HOST}"
|
|
||||||
- name: LOCUST_LOCUSTFILE
|
|
||||||
value: "/locust/locustfile.py"
|
|
||||||
- name: LOCUST_WEB_HOST
|
|
||||||
value: "0.0.0.0"
|
|
||||||
- name: LOCUST_MASTER
|
|
||||||
value: "true"
|
|
||||||
- name: LOCUST_BASE_PATH
|
|
||||||
value: "${LOCUST_BASE_PATH}"
|
|
||||||
- name: INFERENCE_MODEL
|
|
||||||
value: "${BENCHMARK_INFERENCE_MODEL}"
|
|
||||||
volumeMounts:
|
|
||||||
- name: locust-script
|
|
||||||
mountPath: /locust
|
|
||||||
command: ["locust"]
|
|
||||||
args:
|
|
||||||
- "--master"
|
|
||||||
- "--web-host=0.0.0.0"
|
|
||||||
- "--web-port=8089"
|
|
||||||
- "--host=${LOCUST_HOST}"
|
|
||||||
- "--locustfile=/locust/locustfile.py"
|
|
||||||
volumes:
|
|
||||||
- name: locust-script
|
|
||||||
configMap:
|
|
||||||
name: locust-script
|
|
||||||
---
|
|
||||||
apiVersion: apps/v1
|
|
||||||
kind: Deployment
|
|
||||||
metadata:
|
|
||||||
name: locust-worker
|
|
||||||
labels:
|
|
||||||
app: locust
|
|
||||||
role: worker
|
|
||||||
spec:
|
|
||||||
replicas: 2 # Start with 2 workers, can be scaled up
|
|
||||||
selector:
|
|
||||||
matchLabels:
|
|
||||||
app: locust
|
|
||||||
role: worker
|
|
||||||
template:
|
|
||||||
metadata:
|
|
||||||
labels:
|
|
||||||
app: locust
|
|
||||||
role: worker
|
|
||||||
spec:
|
|
||||||
containers:
|
|
||||||
- name: locust-worker
|
|
||||||
image: locustio/locust:2.31.8
|
|
||||||
env:
|
|
||||||
- name: LOCUST_HOST
|
|
||||||
value: "${LOCUST_HOST}"
|
|
||||||
- name: LOCUST_LOCUSTFILE
|
|
||||||
value: "/locust/locustfile.py"
|
|
||||||
- name: LOCUST_MASTER_HOST
|
|
||||||
value: "locust-master-service"
|
|
||||||
- name: LOCUST_MASTER_PORT
|
|
||||||
value: "5557"
|
|
||||||
- name: INFERENCE_MODEL
|
|
||||||
value: "${BENCHMARK_INFERENCE_MODEL}"
|
|
||||||
- name: LOCUST_BASE_PATH
|
|
||||||
value: "${LOCUST_BASE_PATH}"
|
|
||||||
volumeMounts:
|
|
||||||
- name: locust-script
|
|
||||||
mountPath: /locust
|
|
||||||
command: ["locust"]
|
|
||||||
args:
|
|
||||||
- "--worker"
|
|
||||||
- "--master-host=locust-master-service"
|
|
||||||
- "--master-port=5557"
|
|
||||||
- "--locustfile=/locust/locustfile.py"
|
|
||||||
volumes:
|
|
||||||
- name: locust-script
|
|
||||||
configMap:
|
|
||||||
name: locust-script
|
|
||||||
---
|
|
||||||
apiVersion: v1
|
|
||||||
kind: Service
|
|
||||||
metadata:
|
|
||||||
name: locust-master-service
|
|
||||||
spec:
|
|
||||||
selector:
|
|
||||||
app: locust
|
|
||||||
role: master
|
|
||||||
ports:
|
|
||||||
- name: web-ui
|
|
||||||
port: 8089
|
|
||||||
targetPort: 8089
|
|
||||||
- name: master-comm
|
|
||||||
port: 5557
|
|
||||||
targetPort: 5557
|
|
||||||
type: ClusterIP
|
|
||||||
---
|
|
||||||
apiVersion: v1
|
|
||||||
kind: Service
|
|
||||||
metadata:
|
|
||||||
name: locust-web-ui
|
|
||||||
spec:
|
|
||||||
selector:
|
|
||||||
app: locust
|
|
||||||
role: master
|
|
||||||
ports:
|
|
||||||
- port: 8089
|
|
||||||
targetPort: 8089
|
|
||||||
type: ClusterIP # Keep internal, use port-forward to access
|
|
||||||
|
|
@ -1,78 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
"""
|
|
||||||
Locust load testing script for Llama Stack with Prism mock OpenAI provider.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import random
|
|
||||||
from locust import HttpUser, task, between
|
|
||||||
import os
|
|
||||||
|
|
||||||
base_path = os.getenv("LOCUST_BASE_PATH", "/v1/openai/v1")
|
|
||||||
|
|
||||||
MODEL_ID = os.getenv("INFERENCE_MODEL")
|
|
||||||
|
|
||||||
class LlamaStackUser(HttpUser):
|
|
||||||
wait_time = between(0.0, 0.0001)
|
|
||||||
|
|
||||||
def on_start(self):
|
|
||||||
"""Setup authentication and test data."""
|
|
||||||
# No auth required for benchmark server
|
|
||||||
self.headers = {
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Test messages of varying lengths
|
|
||||||
self.test_messages = [
|
|
||||||
[{"role": "user", "content": "Hi"}],
|
|
||||||
[{"role": "user", "content": "What is the capital of France?"}],
|
|
||||||
[{"role": "user", "content": "Explain quantum physics in simple terms."}],
|
|
||||||
[{"role": "user", "content": "Write a short story about a robot learning to paint."}],
|
|
||||||
[
|
|
||||||
{"role": "user", "content": "What is machine learning?"},
|
|
||||||
{"role": "assistant", "content": "Machine learning is a subset of AI..."},
|
|
||||||
{"role": "user", "content": "Can you give me a practical example?"}
|
|
||||||
]
|
|
||||||
]
|
|
||||||
|
|
||||||
@task(weight=100)
|
|
||||||
def chat_completion_streaming(self):
|
|
||||||
"""Test streaming chat completion (20% of requests)."""
|
|
||||||
messages = random.choice(self.test_messages)
|
|
||||||
payload = {
|
|
||||||
"model": MODEL_ID,
|
|
||||||
"messages": messages,
|
|
||||||
"stream": True,
|
|
||||||
"max_tokens": 100
|
|
||||||
}
|
|
||||||
|
|
||||||
with self.client.post(
|
|
||||||
f"{base_path}/chat/completions",
|
|
||||||
headers=self.headers,
|
|
||||||
json=payload,
|
|
||||||
stream=True,
|
|
||||||
catch_response=True
|
|
||||||
) as response:
|
|
||||||
if response.status_code == 200:
|
|
||||||
chunks_received = 0
|
|
||||||
try:
|
|
||||||
for line in response.iter_lines():
|
|
||||||
if line:
|
|
||||||
line_str = line.decode('utf-8')
|
|
||||||
if line_str.startswith('data: '):
|
|
||||||
chunks_received += 1
|
|
||||||
if line_str.strip() == 'data: [DONE]':
|
|
||||||
break
|
|
||||||
|
|
||||||
if chunks_received > 0:
|
|
||||||
response.success()
|
|
||||||
else:
|
|
||||||
response.failure("No streaming chunks received")
|
|
||||||
except Exception as e:
|
|
||||||
response.failure(f"Streaming error: {e}")
|
|
||||||
else:
|
|
||||||
response.failure(f"HTTP {response.status_code}: {response.text}")
|
|
||||||
|
|
@@ -1,52 +0,0 @@
apiVersion: apps/v1
kind: Deployment
metadata:
  name: openai-mock
  labels:
    app: openai-mock
spec:
  replicas: 1
  selector:
    matchLabels:
      app: openai-mock
  template:
    metadata:
      labels:
        app: openai-mock
    spec:
      containers:
      - name: openai-mock
        image: python:3.12-slim
        ports:
        - containerPort: ${MOCK_INFERENCE_PORT}
        env:
        - name: PORT
          value: "${MOCK_INFERENCE_PORT}"
        - name: MOCK_MODELS
          value: "${MOCK_INFERENCE_MODEL}"
        - name: STREAM_DELAY_SECONDS
          value: "${STREAM_DELAY_SECONDS}"
        command: ["sh", "-c"]
        args:
        - |
          pip install flask &&
          python /app/openai-mock-server.py --port ${MOCK_INFERENCE_PORT}
        volumeMounts:
        - name: openai-mock-script
          mountPath: /app
      volumes:
      - name: openai-mock-script
        configMap:
          name: openai-mock
---
apiVersion: v1
kind: Service
metadata:
  name: openai-mock-service
spec:
  selector:
    app: openai-mock
  ports:
  - port: 8080
    targetPort: 8080
  type: ClusterIP
6
docs/source/distributions/k8s-benchmark/openai-mock-server.py
Normal file → Executable file

@@ -23,7 +23,7 @@ app = Flask(__name__)
 # Models from environment variables
 def get_models():
-    models_str = os.getenv("MOCK_MODELS", "mock-inference")
+    models_str = os.getenv("MOCK_MODELS", "meta-llama/Llama-3.2-3B-Instruct")
     model_ids = [m.strip() for m in models_str.split(",") if m.strip()]

     return {
@@ -49,13 +49,13 @@ def generate_random_text(length=50):
     ]
     return " ".join(random.choices(words, k=length))

-@app.route('/models', methods=['GET'])
+@app.route('/v1/models', methods=['GET'])
 def list_models():
     models = get_models()
     print(f"[MOCK] Returning models: {[m['id'] for m in models['data']]}")
     return jsonify(models)

-@app.route('/chat/completions', methods=['POST'])
+@app.route('/v1/chat/completions', methods=['POST'])
 def chat_completions():
     """Return OpenAI-formatted chat completion responses."""
     data = request.get_json()
52
docs/source/distributions/k8s-benchmark/profile_running_server.sh
Executable file

@@ -0,0 +1,52 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Script to profile an already running Llama Stack server
# Usage: ./profile_running_server.sh [duration_seconds] [output_file]

DURATION=${1:-60}  # Default 60 seconds
OUTPUT_FILE=${2:-"llama_stack_profile"}  # Default output file

echo "Looking for running Llama Stack server..."

# Find the server PID
SERVER_PID=$(ps aux | grep "llama_stack.core.server.server" | grep -v grep | awk '{print $2}' | head -1)

if [ -z "$SERVER_PID" ]; then
    echo "Error: No running Llama Stack server found"
    echo "Please start your server first with:"
    echo "LLAMA_STACK_LOGGING=\"all=ERROR\" MOCK_INFERENCE_URL=http://localhost:8080 SAFETY_MODEL=llama-guard3:1b uv run --with llama-stack python -m llama_stack.core.server.server docs/source/distributions/k8s-benchmark/stack_run_config.yaml"
    exit 1
fi

echo "Found Llama Stack server with PID: $SERVER_PID"

# Start py-spy profiling
echo "Starting py-spy profiling for ${DURATION} seconds..."
echo "Output will be saved to: ${OUTPUT_FILE}.svg"
echo ""
echo "You can now run your load test..."
echo ""

# Get the full path to py-spy
PYSPY_PATH=$(which py-spy)

# Check if running as root, if not, use sudo
if [ "$EUID" -ne 0 ]; then
    echo "py-spy requires root permissions on macOS. Running with sudo..."
    sudo "$PYSPY_PATH" record -o "${OUTPUT_FILE}.svg" -d ${DURATION} -p $SERVER_PID
else
    "$PYSPY_PATH" record -o "${OUTPUT_FILE}.svg" -d ${DURATION} -p $SERVER_PID
fi

echo ""
echo "Profiling completed! Results saved to: ${OUTPUT_FILE}.svg"
echo ""
echo "To view the flame graph:"
echo "open ${OUTPUT_FILE}.svg"
148
docs/source/distributions/k8s-benchmark/run-benchmark.sh
Executable file

@@ -0,0 +1,148 @@
#!/usr/bin/env bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

set -euo pipefail

# Default values
TARGET="stack"
DURATION=60
CONCURRENT=10

# Parse command line arguments
usage() {
    echo "Usage: $0 [options]"
    echo "Options:"
    echo "  -t, --target <stack|vllm>     Target to benchmark (default: stack)"
    echo "  -d, --duration <seconds>      Duration in seconds (default: 60)"
    echo "  -c, --concurrent <users>      Number of concurrent users (default: 10)"
    echo "  -h, --help                    Show this help message"
    echo ""
    echo "Examples:"
    echo "  $0 --target vllm              # Benchmark vLLM direct"
    echo "  $0 --target stack             # Benchmark Llama Stack (default)"
    echo "  $0 -t vllm -d 120 -c 20       # vLLM with 120s duration, 20 users"
}

while [[ $# -gt 0 ]]; do
    case $1 in
        -t|--target)
            TARGET="$2"
            shift 2
            ;;
        -d|--duration)
            DURATION="$2"
            shift 2
            ;;
        -c|--concurrent)
            CONCURRENT="$2"
            shift 2
            ;;
        -h|--help)
            usage
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            usage
            exit 1
            ;;
    esac
done

# Validate target
if [[ "$TARGET" != "stack" && "$TARGET" != "vllm" ]]; then
    echo "Error: Target must be 'stack' or 'vllm'"
    usage
    exit 1
fi

# Set configuration based on target
if [[ "$TARGET" == "vllm" ]]; then
    BASE_URL="http://vllm-server:8000/v1"
    JOB_NAME="vllm-benchmark-job"
    echo "Benchmarking vLLM direct..."
else
    BASE_URL="http://llama-stack-benchmark-service:8323/v1/openai/v1"
    JOB_NAME="stack-benchmark-job"
    echo "Benchmarking Llama Stack..."
fi

echo "Configuration:"
echo "  Target: $TARGET"
echo "  Base URL: $BASE_URL"
echo "  Duration: ${DURATION}s"
echo "  Concurrent users: $CONCURRENT"
echo ""

# Create temporary job yaml
TEMP_YAML="/tmp/benchmark-job-temp-$(date +%s).yaml"
cat > "$TEMP_YAML" << EOF
apiVersion: batch/v1
kind: Job
metadata:
  name: $JOB_NAME
  namespace: default
spec:
  template:
    spec:
      containers:
      - name: benchmark
        image: python:3.11-slim
        command: ["/bin/bash"]
        args:
        - "-c"
        - |
          pip install aiohttp &&
          python3 /benchmark/benchmark.py \\
            --base-url $BASE_URL \\
            --model \${INFERENCE_MODEL} \\
            --duration $DURATION \\
            --concurrent $CONCURRENT
        env:
        - name: INFERENCE_MODEL
          value: "meta-llama/Llama-3.2-3B-Instruct"
        volumeMounts:
        - name: benchmark-script
          mountPath: /benchmark
        resources:
          requests:
            memory: "256Mi"
            cpu: "250m"
          limits:
            memory: "512Mi"
            cpu: "500m"
      volumes:
      - name: benchmark-script
        configMap:
          name: benchmark-script
      restartPolicy: Never
  backoffLimit: 3
EOF

echo "Creating benchmark ConfigMap..."
kubectl create configmap benchmark-script \
    --from-file=benchmark.py=benchmark.py \
    --dry-run=client -o yaml | kubectl apply -f -

echo "Cleaning up any existing benchmark job..."
kubectl delete job $JOB_NAME 2>/dev/null || true

echo "Deploying benchmark Job..."
kubectl apply -f "$TEMP_YAML"

echo "Waiting for job to start..."
kubectl wait --for=condition=Ready pod -l job-name=$JOB_NAME --timeout=60s

echo "Following benchmark logs..."
kubectl logs -f job/$JOB_NAME

echo "Job completed. Checking final status..."
kubectl get job $JOB_NAME

# Clean up temporary file
rm -f "$TEMP_YAML"
@@ -26,13 +26,6 @@ data:
        max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
        api_token: ${env.VLLM_API_TOKEN:=fake}
        tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-    - provider_id: mock-vllm-inference
-      provider_type: remote::vllm
-      config:
-        url: http://openai-mock-service:${env.MOCK_INFERENCE_PORT}
-        max_tokens: 4096
-        api_token: fake
-        tls_verify: false
     - provider_id: sentence-transformers
       provider_type: inline::sentence-transformers
       config: {}
@@ -121,9 +114,6 @@ data:
     - model_id: ${env.SAFETY_MODEL}
       provider_id: vllm-safety
       model_type: llm
-    - model_id: ${env.MOCK_INFERENCE_MODEL}
-      provider_id: mock-vllm-inference
-      model_type: llm
     shields:
     - shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
     vector_dbs: []
@@ -44,8 +44,6 @@ spec:
           value: "${SAFETY_MODEL}"
         - name: TAVILY_SEARCH_API_KEY
           value: "${TAVILY_SEARCH_API_KEY}"
-        - name: MOCK_INFERENCE_PORT
-          value: "${MOCK_INFERENCE_PORT}"
         - name: VLLM_URL
           value: http://vllm-server.default.svc.cluster.local:8000/v1
         - name: VLLM_MAX_TOKENS
@@ -54,8 +52,6 @@ spec:
           value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
         - name: VLLM_TLS_VERIFY
           value: "false"
-        - name: MOCK_INFERENCE_MODEL
-          value: "${MOCK_INFERENCE_MODEL}"
         command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
         ports:
         - containerPort: 8323
@@ -3,7 +3,6 @@ image_name: kubernetes-benchmark-demo
 apis:
 - agents
 - inference
-- safety
 - telemetry
 - tool_runtime
 - vector_io
@@ -16,20 +15,6 @@ providers:
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-  - provider_id: vllm-safety
-    provider_type: remote::vllm
-    config:
-      url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}
-      max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
-      api_token: ${env.VLLM_API_TOKEN:=fake}
-      tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-  - provider_id: mock-vllm-inference
-    provider_type: remote::vllm
-    config:
-      url: http://openai-mock-service:${env.MOCK_INFERENCE_PORT}
-      max_tokens: 4096
-      api_token: fake
-      tls_verify: false
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
@@ -45,11 +30,6 @@ providers:
         db: ${env.POSTGRES_DB:=llamastack}
         user: ${env.POSTGRES_USER:=llamastack}
         password: ${env.POSTGRES_PASSWORD:=llamastack}
-  safety:
-  - provider_id: llama-guard
-    provider_type: inline::llama-guard
-    config:
-      excluded_categories: []
   agents:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -115,14 +95,6 @@ models:
 - model_id: ${env.INFERENCE_MODEL}
   provider_id: vllm-inference
   model_type: llm
-- model_id: ${env.SAFETY_MODEL}
-  provider_id: vllm-safety
-  model_type: llm
-- model_id: ${env.MOCK_INFERENCE_MODEL}
-  provider_id: mock-vllm-inference
-  model_type: llm
-shields:
-- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
 vector_dbs: []
 datasets: []
 scoring_fns: []
@@ -2,6 +2,15 @@

 ## Overview

+Agents API for creating and interacting with agentic systems.
+
+Main functionalities provided by this API:
+- Create agents with specific instructions and ability to use tools.
+- Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn".
+- Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details).
+- Agents can be provided with various shields (see the Safety API for more details).
+- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
+
 This section contains documentation for all available providers for the **agents** API.

 ## Providers
21
docs/source/providers/batches/index.md
Normal file

@@ -0,0 +1,21 @@
# Batches

## Overview

Protocol for batch processing API operations.

The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.

Note: This API is currently under active development and may undergo changes.

This section contains documentation for all available providers for the **batches** API.

## Providers

```{toctree}
:maxdepth: 1

inline_reference
```
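Because the protocol mirrors OpenAI's Batches API, the stock `openai` client can exercise it. A minimal sketch, assuming a local Llama Stack server exposing the OpenAI-compatible routes under `/v1/openai/v1` (the port, file name, and model-specific request file contents are illustrative, not part of this change):

```python
# Sketch: drive the new Batches API with the OpenAI client against Llama Stack.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# The input file holds one JSON request per line; it is uploaded with the
# "batch" purpose that this change adds to OpenAIFilePurpose.
batch_input = client.files.create(file=open("requests.jsonl", "rb"), purpose="batch")

batch = client.batches.create(
    input_file_id=batch_input.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={"source": "docs-example"},
)
print(batch.id, batch.status)
```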
23
docs/source/providers/batches/inline_reference.md
Normal file

@@ -0,0 +1,23 @@
# inline::reference

## Description

Reference implementation of batches API with KVStore persistence.

## Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Configuration for the key-value store backend. |
| `max_concurrent_batches` | `<class 'int'>` | No | 1 | Maximum number of concurrent batches to process simultaneously. |
| `max_concurrent_requests_per_batch` | `<class 'int'>` | No | 10 | Maximum number of concurrent requests to process per batch. |

## Sample Configuration

```yaml
kvstore:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db

```
@@ -2,6 +2,8 @@

 ## Overview

+Llama Stack Evaluation API for running evaluations on model and agent candidates.
+
 This section contains documentation for all available providers for the **eval** API.

 ## Providers
@@ -2,6 +2,12 @@

 ## Overview

+Llama Stack Inference API for generating completions, chat completions, and embeddings.
+
+This API provides the raw interface to the underlying models. Two kinds of models are supported:
+- LLM models: these models generate "raw" and "chat" (conversational) completions.
+- Embedding models: these models generate embeddings to be used for semantic search.
+
 This section contains documentation for all available providers for the **inference** API.

 ## Providers
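For the chat-completion side of this API, a minimal sketch using the stock `openai` client against the same `/v1/openai/v1` base path the benchmark scripts in this PR target (the URL, port, and model id are illustrative assumptions):

```python
# Sketch: issue a chat completion through Llama Stack's OpenAI-compatible surface.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    max_tokens=100,
)
print(response.choices[0].message.content)
```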
9
llama_stack/apis/batches/__init__.py
Normal file

@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .batches import Batches, BatchObject, ListBatchesResponse

__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
89
llama_stack/apis/batches/batches.py
Normal file

@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Literal, Protocol, runtime_checkable

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type, webmethod

try:
    from openai.types import Batch as BatchObject
except ImportError as e:
    raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e


@json_schema_type
class ListBatchesResponse(BaseModel):
    """Response containing a list of batch objects."""

    object: Literal["list"] = "list"
    data: list[BatchObject] = Field(..., description="List of batch objects")
    first_id: str | None = Field(default=None, description="ID of the first batch in the list")
    last_id: str | None = Field(default=None, description="ID of the last batch in the list")
    has_more: bool = Field(default=False, description="Whether there are more batches available")


@runtime_checkable
class Batches(Protocol):
    """Protocol for batch processing API operations.

    The Batches API enables efficient processing of multiple requests in a single operation,
    particularly useful for processing large datasets, batch evaluation workflows, and
    cost-effective inference at scale.

    Note: This API is currently under active development and may undergo changes.
    """

    @webmethod(route="/openai/v1/batches", method="POST")
    async def create_batch(
        self,
        input_file_id: str,
        endpoint: str,
        completion_window: Literal["24h"],
        metadata: dict[str, str] | None = None,
    ) -> BatchObject:
        """Create a new batch for processing multiple API requests.

        :param input_file_id: The ID of an uploaded file containing requests for the batch.
        :param endpoint: The endpoint to be used for all requests in the batch.
        :param completion_window: The time window within which the batch should be processed.
        :param metadata: Optional metadata for the batch.
        :returns: The created batch object.
        """
        ...

    @webmethod(route="/openai/v1/batches/{batch_id}", method="GET")
    async def retrieve_batch(self, batch_id: str) -> BatchObject:
        """Retrieve information about a specific batch.

        :param batch_id: The ID of the batch to retrieve.
        :returns: The batch object.
        """
        ...

    @webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST")
    async def cancel_batch(self, batch_id: str) -> BatchObject:
        """Cancel a batch that is in progress.

        :param batch_id: The ID of the batch to cancel.
        :returns: The updated batch object.
        """
        ...

    @webmethod(route="/openai/v1/batches", method="GET")
    async def list_batches(
        self,
        after: str | None = None,
        limit: int = 20,
    ) -> ListBatchesResponse:
        """List all batches for the current user.

        :param after: A cursor for pagination; returns batches after this batch ID.
        :param limit: Number of batches to return (default 20, max 100).
        :returns: A list of batch objects.
        """
        ...
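For illustration only, a minimal in-memory object sketching two of the protocol's methods. The provider actually added in this PR (`inline::reference`) persists batches in a KVStore; the `openai.types.Batch` constructor fields used here are assumptions that may vary across SDK versions.

```python
# Illustrative sketch only -- not the inline::reference implementation.
import time
import uuid
from typing import Literal

from openai.types import Batch as BatchObject  # same alias the protocol uses


class InMemoryBatches:
    """Sketches the create/retrieve half of the Batches protocol in memory."""

    def __init__(self) -> None:
        self._batches: dict[str, BatchObject] = {}

    async def create_batch(
        self,
        input_file_id: str,
        endpoint: str,
        completion_window: Literal["24h"],
        metadata: dict[str, str] | None = None,
    ) -> BatchObject:
        # Field set assumed from the openai SDK; it may differ between versions.
        batch = BatchObject(
            id=f"batch_{uuid.uuid4().hex}",
            object="batch",
            endpoint=endpoint,
            input_file_id=input_file_id,
            completion_window=completion_window,
            status="validating",
            created_at=int(time.time()),
            metadata=metadata,
        )
        self._batches[batch.id] = batch
        return batch

    async def retrieve_batch(self, batch_id: str) -> BatchObject:
        return self._batches[batch_id]
```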
@@ -72,3 +72,10 @@ class ModelTypeError(TypeError):
             f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'"
         )
         super().__init__(message)
+
+
+class ConflictError(ValueError):
+    """raised when an operation cannot be performed due to a conflict with the current state"""
+
+    def __init__(self, message: str) -> None:
+        super().__init__(message)
@@ -86,6 +86,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
     :cvar inference: Text generation, chat completions, and embeddings
     :cvar safety: Content moderation and safety shields
     :cvar agents: Agent orchestration and execution
+    :cvar batches: Batch processing for asynchronous API requests
     :cvar vector_io: Vector database operations and queries
     :cvar datasetio: Dataset input/output operations
     :cvar scoring: Model output evaluation and scoring
@@ -108,6 +109,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
     inference = "inference"
     safety = "safety"
     agents = "agents"
+    batches = "batches"
     vector_io = "vector_io"
     datasetio = "datasetio"
     scoring = "scoring"
@@ -22,6 +22,7 @@ class OpenAIFilePurpose(StrEnum):
     """

     ASSISTANTS = "assistants"
+    BATCH = "batch"
     # TODO: Add other purposes as needed
@@ -8,6 +8,7 @@ import inspect
 from typing import Any

 from llama_stack.apis.agents import Agents
+from llama_stack.apis.batches import Batches
 from llama_stack.apis.benchmarks import Benchmarks
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
@@ -75,6 +76,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
         Api.agents: Agents,
         Api.inference: Inference,
         Api.inspect: Inspect,
+        Api.batches: Batches,
         Api.vector_io: VectorIO,
         Api.vector_dbs: VectorDBs,
         Api.models: Models,
@@ -32,6 +32,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
 from openai import BadRequestError
 from pydantic import BaseModel, ValidationError

+from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
 from llama_stack.core.access_control.access_control import AccessDeniedError
@@ -128,6 +129,10 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
             ]
         },
     )
+    elif isinstance(exc, ConflictError):
+        return HTTPException(status_code=409, detail=str(exc))
+    elif isinstance(exc, ResourceNotFoundError):
+        return HTTPException(status_code=404, detail=str(exc))
     elif isinstance(exc, ValueError):
         return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
     elif isinstance(exc, BadRequestError):
@@ -48,6 +48,8 @@ distribution_spec:
     - provider_type: remote::tavily-search
     - provider_type: inline::rag-runtime
     - provider_type: remote::model-context-protocol
+  batches:
+  - provider_type: inline::reference
 image_type: venv
 additional_pip_packages:
 - aiosqlite
@@ -2,6 +2,7 @@ version: 2
 image_name: ci-tests
 apis:
 - agents
+- batches
 - datasetio
 - eval
 - files
@@ -204,6 +205,13 @@ providers:
     provider_type: inline::rag-runtime
   - provider_id: model-context-protocol
     provider_type: remote::model-context-protocol
+  batches:
+  - provider_id: reference
+    provider_type: inline::reference
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/batches.db
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/registry.db
@@ -48,6 +48,8 @@ distribution_spec:
     - provider_type: remote::tavily-search
     - provider_type: inline::rag-runtime
     - provider_type: remote::model-context-protocol
+  batches:
+  - provider_type: inline::reference
 image_type: venv
 additional_pip_packages:
 - aiosqlite
@@ -2,6 +2,7 @@ version: 2
 image_name: starter
 apis:
 - agents
+- batches
 - datasetio
 - eval
 - files
@@ -204,6 +205,13 @@ providers:
     provider_type: inline::rag-runtime
   - provider_id: model-context-protocol
     provider_type: remote::model-context-protocol
+  batches:
+  - provider_id: reference
+    provider_type: inline::reference
+    config:
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/batches.db
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/registry.db
@@ -139,6 +139,9 @@ def get_distribution_template() -> DistributionTemplate:
             BuildProvider(provider_type="inline::rag-runtime"),
             BuildProvider(provider_type="remote::model-context-protocol"),
         ],
+        "batches": [
+            BuildProvider(provider_type="inline::reference"),
+        ],
     }
     files_provider = Provider(
         provider_id="meta-reference-files",
@@ -48,8 +48,8 @@ from llama_stack.providers.utils.responses.responses_store import ResponsesStore

 from .agent_instance import ChatAgent
 from .config import MetaReferenceAgentsImplConfig
-from .openai_responses import OpenAIResponsesImpl
 from .persistence import AgentInfo
+from .responses.openai_responses import OpenAIResponsesImpl

 logger = logging.getLogger()
File diff suppressed because it is too large
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@ -0,0 +1,271 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.apis.agents import Order
|
||||||
|
from llama_stack.apis.agents.openai_responses import (
|
||||||
|
ListOpenAIResponseInputItem,
|
||||||
|
ListOpenAIResponseObject,
|
||||||
|
OpenAIDeleteResponseObject,
|
||||||
|
OpenAIResponseInput,
|
||||||
|
OpenAIResponseInputMessageContentText,
|
||||||
|
OpenAIResponseInputTool,
|
||||||
|
OpenAIResponseMessage,
|
||||||
|
OpenAIResponseObject,
|
||||||
|
OpenAIResponseObjectStream,
|
||||||
|
OpenAIResponseText,
|
||||||
|
OpenAIResponseTextFormat,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.inference import (
|
||||||
|
Inference,
|
||||||
|
OpenAISystemMessageParam,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||||
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||||
|
|
||||||
|
from .streaming import StreamingResponseOrchestrator
|
||||||
|
from .tool_executor import ToolExecutor
|
||||||
|
from .types import ChatCompletionContext
|
||||||
|
from .utils import (
|
||||||
|
convert_response_input_to_chat_messages,
|
||||||
|
convert_response_text_to_chat_response_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="responses")
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||||
|
input_items: ListOpenAIResponseInputItem
|
||||||
|
response: OpenAIResponseObject
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIResponsesImpl:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inference_api: Inference,
|
||||||
|
tool_groups_api: ToolGroups,
|
||||||
|
tool_runtime_api: ToolRuntime,
|
||||||
|
responses_store: ResponsesStore,
|
||||||
|
vector_io_api: VectorIO, # VectorIO
|
||||||
|
):
|
||||||
|
self.inference_api = inference_api
|
||||||
|
self.tool_groups_api = tool_groups_api
|
||||||
|
self.tool_runtime_api = tool_runtime_api
|
||||||
|
self.responses_store = responses_store
|
||||||
|
self.vector_io_api = vector_io_api
|
||||||
|
self.tool_executor = ToolExecutor(
|
||||||
|
tool_groups_api=tool_groups_api,
|
||||||
|
tool_runtime_api=tool_runtime_api,
|
||||||
|
vector_io_api=vector_io_api,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _prepend_previous_response(
|
||||||
|
self,
|
||||||
|
input: str | list[OpenAIResponseInput],
|
||||||
|
previous_response_id: str | None = None,
|
||||||
|
):
|
||||||
|
if previous_response_id:
|
||||||
|
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
|
||||||
|
|
||||||
|
# previous response input items
|
||||||
|
new_input_items = previous_response_with_input.input
|
||||||
|
|
||||||
|
# previous response output items
|
||||||
|
new_input_items.extend(previous_response_with_input.output)
|
||||||
|
|
||||||
|
# new input items from the current request
|
||||||
|
if isinstance(input, str):
|
||||||
|
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
|
||||||
|
else:
|
||||||
|
new_input_items.extend(input)
|
||||||
|
|
||||||
|
input = new_input_items
|
||||||
|
|
||||||
|
return input
|
||||||
|
|
||||||
|
async def _prepend_instructions(self, messages, instructions):
|
||||||
|
if instructions:
|
||||||
|
messages.insert(0, OpenAISystemMessageParam(content=instructions))
|
||||||
|
|
||||||
|
async def get_openai_response(
|
||||||
|
self,
|
||||||
|
response_id: str,
|
||||||
|
) -> OpenAIResponseObject:
|
||||||
|
response_with_input = await self.responses_store.get_response_object(response_id)
|
||||||
|
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
|
||||||
|
|
||||||
|
async def list_openai_responses(
|
||||||
|
self,
|
||||||
|
after: str | None = None,
|
||||||
|
limit: int | None = 50,
|
||||||
|
model: str | None = None,
|
||||||
|
order: Order | None = Order.desc,
|
||||||
|
) -> ListOpenAIResponseObject:
|
||||||
|
return await self.responses_store.list_responses(after, limit, model, order)
|
||||||
|
|
||||||
|
async def list_openai_response_input_items(
|
||||||
|
self,
|
||||||
|
response_id: str,
|
||||||
|
after: str | None = None,
|
||||||
|
before: str | None = None,
|
||||||
|
include: list[str] | None = None,
|
||||||
|
limit: int | None = 20,
|
||||||
|
order: Order | None = Order.desc,
|
||||||
|
) -> ListOpenAIResponseInputItem:
|
||||||
|
"""List input items for a given OpenAI response.
|
||||||
|
|
||||||
|
:param response_id: The ID of the response to retrieve input items for.
|
||||||
|
:param after: An item ID to list items after, used for pagination.
|
||||||
|
:param before: An item ID to list items before, used for pagination.
|
||||||
|
:param include: Additional fields to include in the response.
|
||||||
|
:param limit: A limit on the number of objects to be returned.
|
||||||
|
:param order: The order to return the input items in.
|
||||||
|
:returns: An ListOpenAIResponseInputItem.
|
||||||
|
"""
|
||||||
|
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
|
||||||
|
|
||||||
|
async def _store_response(
|
||||||
|
self,
|
||||||
|
response: OpenAIResponseObject,
|
||||||
|
input: str | list[OpenAIResponseInput],
|
||||||
|
) -> None:
|
||||||
|
new_input_id = f"msg_{uuid.uuid4()}"
|
||||||
|
if isinstance(input, str):
|
||||||
|
# synthesize a message from the input string
|
||||||
|
input_content = OpenAIResponseInputMessageContentText(text=input)
|
||||||
|
input_content_item = OpenAIResponseMessage(
|
||||||
|
role="user",
|
||||||
|
content=[input_content],
|
||||||
|
id=new_input_id,
|
||||||
|
)
|
||||||
|
input_items_data = [input_content_item]
|
||||||
|
else:
|
||||||
|
# we already have a list of messages
|
||||||
|
input_items_data = []
|
||||||
|
for input_item in input:
|
||||||
|
if isinstance(input_item, OpenAIResponseMessage):
|
||||||
|
# These may or may not already have an id, so dump to dict, check for id, and add if missing
|
||||||
|
input_item_dict = input_item.model_dump()
|
||||||
|
if "id" not in input_item_dict:
|
||||||
|
input_item_dict["id"] = new_input_id
|
||||||
|
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
|
||||||
|
else:
|
||||||
|
input_items_data.append(input_item)
|
||||||
|
|
||||||
|
await self.responses_store.store_response_object(
|
||||||
|
response_object=response,
|
||||||
|
input=input_items_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_openai_response(
|
||||||
|
self,
|
||||||
|
input: str | list[OpenAIResponseInput],
|
||||||
|
model: str,
|
||||||
|
instructions: str | None = None,
|
||||||
|
previous_response_id: str | None = None,
|
||||||
|
store: bool | None = True,
|
||||||
|
stream: bool | None = False,
|
||||||
|
temperature: float | None = None,
|
||||||
|
text: OpenAIResponseText | None = None,
|
||||||
|
tools: list[OpenAIResponseInputTool] | None = None,
|
||||||
|
include: list[str] | None = None,
|
||||||
|
max_infer_iters: int | None = 10,
|
||||||
|
):
|
||||||
|
stream = bool(stream)
|
||||||
|
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||||
|
|
||||||
|
stream_gen = self._create_streaming_response(
|
||||||
|
input=input,
|
||||||
|
model=model,
|
||||||
|
instructions=instructions,
|
||||||
|
previous_response_id=previous_response_id,
|
||||||
|
store=store,
|
||||||
|
temperature=temperature,
|
||||||
|
text=text,
|
||||||
|
tools=tools,
|
||||||
|
max_infer_iters=max_infer_iters,
|
||||||
|
)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return stream_gen
|
||||||
|
else:
|
||||||
|
response = None
|
||||||
|
async for stream_chunk in stream_gen:
|
||||||
|
if stream_chunk.type == "response.completed":
|
||||||
|
if response is not None:
|
||||||
|
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
|
||||||
|
response = stream_chunk.response
|
||||||
|
# don't leave the generator half complete!
|
||||||
|
|
||||||
|
if response is None:
|
||||||
|
raise ValueError("The response stream never completed")
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def _create_streaming_response(
|
||||||
|
self,
|
||||||
|
input: str | list[OpenAIResponseInput],
|
||||||
|
model: str,
|
||||||
|
instructions: str | None = None,
|
||||||
|
previous_response_id: str | None = None,
|
||||||
|
store: bool | None = True,
|
||||||
|
temperature: float | None = None,
|
||||||
|
text: OpenAIResponseText | None = None,
|
||||||
|
tools: list[OpenAIResponseInputTool] | None = None,
|
||||||
|
max_infer_iters: int | None = 10,
|
||||||
|
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||||
|
# Input preprocessing
|
||||||
|
input = await self._prepend_previous_response(input, previous_response_id)
|
||||||
|
messages = await convert_response_input_to_chat_messages(input)
|
||||||
|
await self._prepend_instructions(messages, instructions)
|
||||||
|
|
||||||
|
# Structured outputs
|
||||||
|
response_format = await convert_response_text_to_chat_response_format(text)
|
||||||
|
|
||||||
|
ctx = ChatCompletionContext(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
response_tools=tools,
|
||||||
|
temperature=temperature,
|
||||||
|
response_format=response_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create orchestrator and delegate streaming logic
|
||||||
|
response_id = f"resp-{uuid.uuid4()}"
|
||||||
|
created_at = int(time.time())
|
||||||
|
|
||||||
|
orchestrator = StreamingResponseOrchestrator(
|
||||||
|
inference_api=self.inference_api,
|
||||||
|
ctx=ctx,
|
||||||
|
response_id=response_id,
|
||||||
|
created_at=created_at,
|
||||||
|
text=text,
|
||||||
|
max_infer_iters=max_infer_iters,
|
||||||
|
tool_executor=self.tool_executor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stream the response
|
||||||
|
final_response = None
|
||||||
|
async for stream_chunk in orchestrator.create_response():
|
||||||
|
if stream_chunk.type == "response.completed":
|
||||||
|
final_response = stream_chunk.response
|
||||||
|
yield stream_chunk
|
||||||
|
|
||||||
|
# Store the response if requested
|
||||||
|
if store and final_response:
|
||||||
|
await self._store_response(
|
||||||
|
response=final_response,
|
||||||
|
input=input,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||||
|
return await self.responses_store.delete_response_object(response_id)
|
||||||
|
|
@ -0,0 +1,634 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.apis.agents.openai_responses import (
|
||||||
|
AllowedToolsFilter,
|
||||||
|
MCPListToolsTool,
|
||||||
|
OpenAIResponseContentPartOutputText,
|
||||||
|
OpenAIResponseInputTool,
|
||||||
|
OpenAIResponseInputToolMCP,
|
||||||
|
OpenAIResponseObject,
|
||||||
|
OpenAIResponseObjectStream,
|
||||||
|
OpenAIResponseObjectStreamResponseCompleted,
|
||||||
|
OpenAIResponseObjectStreamResponseContentPartAdded,
|
||||||
|
OpenAIResponseObjectStreamResponseContentPartDone,
|
||||||
|
OpenAIResponseObjectStreamResponseCreated,
|
||||||
|
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
|
||||||
|
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
|
||||||
|
OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
|
||||||
|
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
|
||||||
|
OpenAIResponseObjectStreamResponseMcpListToolsCompleted,
|
||||||
|
OpenAIResponseObjectStreamResponseMcpListToolsInProgress,
|
||||||
|
OpenAIResponseObjectStreamResponseOutputItemAdded,
|
||||||
|
OpenAIResponseObjectStreamResponseOutputItemDone,
|
||||||
|
OpenAIResponseObjectStreamResponseOutputTextDelta,
|
||||||
|
OpenAIResponseOutput,
|
||||||
|
OpenAIResponseOutputMessageFunctionToolCall,
|
||||||
|
OpenAIResponseOutputMessageMCPListTools,
|
||||||
|
OpenAIResponseText,
|
||||||
|
WebSearchToolTypes,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.inference import (
|
||||||
|
Inference,
|
||||||
|
OpenAIAssistantMessageParam,
|
||||||
|
OpenAIChatCompletion,
|
||||||
|
OpenAIChatCompletionToolCall,
|
||||||
|
OpenAIChoice,
|
||||||
|
)
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
from .types import ChatCompletionContext, ChatCompletionResult
|
||||||
|
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="responses")
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingResponseOrchestrator:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inference_api: Inference,
|
||||||
|
ctx: ChatCompletionContext,
|
||||||
|
response_id: str,
|
||||||
|
created_at: int,
|
||||||
|
text: OpenAIResponseText,
|
||||||
|
max_infer_iters: int,
|
||||||
|
tool_executor, # Will be the tool execution logic from the main class
|
||||||
|
):
|
||||||
|
self.inference_api = inference_api
|
||||||
|
self.ctx = ctx
|
||||||
|
self.response_id = response_id
|
||||||
|
self.created_at = created_at
|
||||||
|
self.text = text
|
||||||
|
self.max_infer_iters = max_infer_iters
|
||||||
|
self.tool_executor = tool_executor
|
||||||
|
self.sequence_number = 0
|
||||||
|
# Store MCP tool mapping that gets built during tool processing
|
||||||
|
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||||
|
|
||||||
|
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||||
|
# Initialize output messages
|
||||||
|
output_messages: list[OpenAIResponseOutput] = []
|
||||||
|
# Create initial response and emit response.created immediately
|
||||||
|
initial_response = OpenAIResponseObject(
|
||||||
|
created_at=self.created_at,
|
||||||
|
id=self.response_id,
|
||||||
|
model=self.ctx.model,
|
||||||
|
object="response",
|
||||||
|
status="in_progress",
|
||||||
|
output=output_messages.copy(),
|
||||||
|
text=self.text,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
|
||||||
|
|
||||||
|
# Process all tools (including MCP tools) and emit streaming events
|
||||||
|
if self.ctx.response_tools:
|
||||||
|
async for stream_event in self._process_tools(self.ctx.response_tools, output_messages):
|
||||||
|
yield stream_event
|
||||||
|
|
||||||
|
n_iter = 0
|
||||||
|
messages = self.ctx.messages.copy()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
completion_result = await self.inference_api.openai_chat_completion(
|
||||||
|
model=self.ctx.model,
|
||||||
|
messages=messages,
|
||||||
|
tools=self.ctx.chat_tools,
|
||||||
|
stream=True,
|
||||||
|
temperature=self.ctx.temperature,
|
||||||
|
response_format=self.ctx.response_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process streaming chunks and build complete response
|
||||||
|
completion_result_data = None
|
||||||
|
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
|
||||||
|
if isinstance(stream_event_or_result, ChatCompletionResult):
|
||||||
|
completion_result_data = stream_event_or_result
|
||||||
|
else:
|
||||||
|
yield stream_event_or_result
|
||||||
|
if not completion_result_data:
|
||||||
|
raise ValueError("Streaming chunk processor failed to return completion data")
|
||||||
|
current_response = self._build_chat_completion(completion_result_data)
|
||||||
|
|
||||||
|
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
|
||||||
|
current_response, messages
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle choices with no tool calls
|
||||||
|
for choice in current_response.choices:
|
||||||
|
if not (choice.message.tool_calls and self.ctx.response_tools):
|
||||||
|
output_messages.append(await convert_chat_choice_to_response_message(choice))
|
||||||
|
|
||||||
|
# Execute tool calls and coordinate results
|
||||||
|
async for stream_event in self._coordinate_tool_execution(
|
||||||
|
function_tool_calls,
|
||||||
|
non_function_tool_calls,
|
||||||
|
completion_result_data,
|
||||||
|
output_messages,
|
||||||
|
next_turn_messages,
|
||||||
|
):
|
||||||
|
yield stream_event
|
||||||
|
|
||||||
|
if not function_tool_calls and not non_function_tool_calls:
|
||||||
|
break
|
||||||
|
|
||||||
|
if function_tool_calls:
|
||||||
|
logger.info("Exiting inference loop since there is a function (client-side) tool call")
|
||||||
|
break
|
||||||
|
|
||||||
|
n_iter += 1
|
||||||
|
if n_iter >= self.max_infer_iters:
|
||||||
|
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
|
||||||
|
break
|
||||||
|
|
||||||
|
messages = next_turn_messages
|
||||||
|
|
||||||
|
# Create final response
|
||||||
|
final_response = OpenAIResponseObject(
|
||||||
|
created_at=self.created_at,
|
||||||
|
id=self.response_id,
|
||||||
|
model=self.ctx.model,
|
||||||
|
object="response",
|
||||||
|
status="completed",
|
||||||
|
text=self.text,
|
||||||
|
output=output_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit response.completed
|
||||||
|
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||||
|
|
||||||
|
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
|
||||||
|
"""Separate tool calls into function and non-function categories."""
|
||||||
|
function_tool_calls = []
|
||||||
|
non_function_tool_calls = []
|
||||||
|
next_turn_messages = messages.copy()
|
||||||
|
|
||||||
|
for choice in current_response.choices:
|
||||||
|
next_turn_messages.append(choice.message)
|
||||||
|
|
||||||
|
if choice.message.tool_calls and self.ctx.response_tools:
|
||||||
|
for tool_call in choice.message.tool_calls:
|
||||||
|
if is_function_tool_call(tool_call, self.ctx.response_tools):
|
||||||
|
function_tool_calls.append(tool_call)
|
||||||
|
else:
|
||||||
|
non_function_tool_calls.append(tool_call)
|
||||||
|
|
||||||
|
return function_tool_calls, non_function_tool_calls, next_turn_messages
|
||||||
|
|
||||||
|
async def _process_streaming_chunks(
|
||||||
|
self, completion_result, output_messages: list[OpenAIResponseOutput]
|
||||||
|
) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]:
|
||||||
|
"""Process streaming chunks and emit events, returning completion data."""
|
||||||
|
# Initialize result tracking
|
||||||
|
chat_response_id = ""
|
||||||
|
chat_response_content = []
|
||||||
|
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
||||||
|
chunk_created = 0
|
||||||
|
chunk_model = ""
|
||||||
|
chunk_finish_reason = ""
|
||||||
|
|
||||||
|
# Create a placeholder message item for delta events
|
||||||
|
message_item_id = f"msg_{uuid.uuid4()}"
|
||||||
|
# Track tool call items for streaming events
|
||||||
|
tool_call_item_ids: dict[int, str] = {}
|
||||||
|
# Track content parts for streaming events
|
||||||
|
content_part_emitted = False
|
||||||
|
|
||||||
|
async for chunk in completion_result:
|
||||||
|
chat_response_id = chunk.id
|
||||||
|
chunk_created = chunk.created
|
||||||
|
chunk_model = chunk.model
|
||||||
|
for chunk_choice in chunk.choices:
|
||||||
|
# Emit incremental text content as delta events
|
||||||
|
if chunk_choice.delta.content:
|
||||||
|
# Emit content_part.added event for first text chunk
|
||||||
|
if not content_part_emitted:
|
||||||
|
content_part_emitted = True
|
||||||
|
self.sequence_number += 1
|
||||||
|
yield OpenAIResponseObjectStreamResponseContentPartAdded(
|
||||||
|
response_id=self.response_id,
|
||||||
|
item_id=message_item_id,
|
||||||
|
part=OpenAIResponseContentPartOutputText(
|
||||||
|
text="", # Will be filled incrementally via text deltas
|
||||||
|
),
|
||||||
|
sequence_number=self.sequence_number,
|
||||||
|
)
|
||||||
|
self.sequence_number += 1
|
||||||
|
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||||
|
content_index=0,
|
||||||
|
delta=chunk_choice.delta.content,
|
||||||
|
item_id=message_item_id,
|
||||||
|
output_index=0,
|
||||||
|
sequence_number=self.sequence_number,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect content for final response
|
||||||
|
chat_response_content.append(chunk_choice.delta.content or "")
|
||||||
|
if chunk_choice.finish_reason:
|
||||||
|
chunk_finish_reason = chunk_choice.finish_reason
|
||||||
|
|
||||||
|
# Aggregate tool call arguments across chunks
|
||||||
|
if chunk_choice.delta.tool_calls:
|
||||||
|
for tool_call in chunk_choice.delta.tool_calls:
|
||||||
|
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||||
|
# Create new tool call entry if this is the first chunk for this index
|
||||||
|
is_new_tool_call = response_tool_call is None
|
||||||
|
if is_new_tool_call:
|
||||||
|
tool_call_dict: dict[str, Any] = tool_call.model_dump()
|
||||||
|
tool_call_dict.pop("type", None)
|
||||||
|
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
|
||||||
|
chat_response_tool_calls[tool_call.index] = response_tool_call
|
||||||
|
|
||||||
|
# Create item ID for this tool call for streaming events
|
||||||
|
tool_call_item_id = f"fc_{uuid.uuid4()}"
|
||||||
|
tool_call_item_ids[tool_call.index] = tool_call_item_id
|
||||||
|
|
||||||
|
# Emit output_item.added event for the new function call
|
||||||
|
self.sequence_number += 1
|
||||||
|
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
|
||||||
|
arguments="", # Will be filled incrementally via delta events
|
||||||
|
call_id=tool_call.id or "",
|
||||||
|
name=tool_call.function.name if tool_call.function else "",
|
||||||
|
id=tool_call_item_id,
|
||||||
|
status="in_progress",
|
||||||
|
)
|
||||||
|
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||||
|
response_id=self.response_id,
|
||||||
|
item=function_call_item,
|
||||||
|
output_index=len(output_messages),
|
||||||
|
sequence_number=self.sequence_number,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stream tool call arguments as they arrive (differentiate between MCP and function calls)
|
||||||
|
if tool_call.function and tool_call.function.arguments:
|
||||||
|
tool_call_item_id = tool_call_item_ids[tool_call.index]
|
||||||
|
self.sequence_number += 1
|
||||||
|
|
||||||
|
# Check if this is an MCP tool call
|
||||||
|
is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
|
||||||
|
if is_mcp_tool:
|
||||||
|
# Emit MCP-specific argument delta event
|
||||||
|
yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(
|
||||||
|
delta=tool_call.function.arguments,
|
||||||
|
item_id=tool_call_item_id,
|
||||||
|
output_index=len(output_messages),
|
||||||
|
sequence_number=self.sequence_number,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Emit function call argument delta event
|
||||||
|
yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
|
||||||
|
delta=tool_call.function.arguments,
|
||||||
|
item_id=tool_call_item_id,
|
||||||
|
output_index=len(output_messages),
|
||||||
|
sequence_number=self.sequence_number,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Accumulate arguments for final response (only for subsequent chunks)
|
||||||
|
if not is_new_tool_call:
|
||||||
|
response_tool_call.function.arguments = (
|
||||||
|
response_tool_call.function.arguments or ""
|
||||||
|
) + tool_call.function.arguments
|
||||||
|
|
||||||
|
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
|
||||||
|
for tool_call_index in sorted(chat_response_tool_calls.keys()):
|
||||||
|
tool_call_item_id = tool_call_item_ids[tool_call_index]
|
||||||
|
final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
|
||||||
|
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
|
||||||
|
|
||||||
|
# Check if this is an MCP tool call
|
||||||
|
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
|
||||||
|
self.sequence_number += 1
|
||||||
|
done_event_cls = (
|
||||||
|
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone
|
||||||
|
if is_mcp_tool
|
||||||
|
else OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone
|
||||||
|
)
|
||||||
|
yield done_event_cls(
|
||||||
|
arguments=final_arguments,
|
||||||
|
item_id=tool_call_item_id,
|
||||||
|
output_index=len(output_messages),
|
||||||
|
sequence_number=self.sequence_number,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit content_part.done event if text content was streamed (before content gets cleared)
|
||||||
|
if content_part_emitted:
|
||||||
|
final_text = "".join(chat_response_content)
|
||||||
|
self.sequence_number += 1
|
||||||
|
yield OpenAIResponseObjectStreamResponseContentPartDone(
|
||||||
|
response_id=self.response_id,
|
||||||
|
item_id=message_item_id,
|
||||||
|
part=OpenAIResponseContentPartOutputText(
|
||||||
|
text=final_text,
|
||||||
|
),
|
||||||
|
sequence_number=self.sequence_number,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clear content when there are tool calls (OpenAI spec behavior)
|
||||||
|
if chat_response_tool_calls:
|
||||||
|
chat_response_content = []
|
||||||
|
|
||||||
|
yield ChatCompletionResult(
|
||||||
|
response_id=chat_response_id,
|
||||||
|
content=chat_response_content,
|
||||||
|
tool_calls=chat_response_tool_calls,
|
||||||
|
created=chunk_created,
|
||||||
|
model=chunk_model,
|
||||||
|
finish_reason=chunk_finish_reason,
|
||||||
|
message_item_id=message_item_id,
|
||||||
|
tool_call_item_ids=tool_call_item_ids,
|
||||||
|
content_part_emitted=content_part_emitted,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion:
|
||||||
|
"""Build OpenAIChatCompletion from ChatCompletionResult."""
|
||||||
|
# Convert collected chunks to complete response
|
||||||
|
if result.tool_calls:
|
||||||
|
tool_calls = [result.tool_calls[i] for i in sorted(result.tool_calls.keys())]
|
||||||
|
else:
|
||||||
|
tool_calls = None
|
||||||
|
|
||||||
|
assistant_message = OpenAIAssistantMessageParam(
|
||||||
|
content=result.content_text,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
)
|
||||||
|
return OpenAIChatCompletion(
|
||||||
|
id=result.response_id,
|
||||||
|
choices=[
|
||||||
|
OpenAIChoice(
|
||||||
|
message=assistant_message,
|
||||||
|
finish_reason=result.finish_reason,
|
||||||
|
index=0,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=result.created,
|
||||||
|
model=result.model,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _coordinate_tool_execution(
|
||||||
|
self,
|
||||||
|
function_tool_calls: list,
|
||||||
|
non_function_tool_calls: list,
|
||||||
|
completion_result_data: ChatCompletionResult,
|
||||||
|
output_messages: list[OpenAIResponseOutput],
|
||||||
|
next_turn_messages: list,
|
||||||
|
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||||
|
"""Coordinate execution of both function and non-function tool calls."""
|
||||||
|
# Execute non-function tool calls
|
||||||
|
for tool_call in non_function_tool_calls:
|
||||||
|
# Find the item_id for this tool call
|
||||||
|
matching_item_id = None
|
||||||
|
for index, item_id in completion_result_data.tool_call_item_ids.items():
|
||||||
|
response_tool_call = completion_result_data.tool_calls.get(index)
|
||||||
|
if response_tool_call and response_tool_call.id == tool_call.id:
|
||||||
|
matching_item_id = item_id
|
||||||
|
break
|
||||||
|
|
||||||
|
# Use a fallback item_id if not found
|
||||||
|
if not matching_item_id:
|
||||||
|
matching_item_id = f"tc_{uuid.uuid4()}"
|
||||||
|
|
||||||
|
# Execute tool call with streaming
|
||||||
|
tool_call_log = None
|
||||||
|
tool_response_message = None
|
||||||
|
async for result in self.tool_executor.execute_tool_call(
|
||||||
|
tool_call,
|
||||||
|
self.ctx,
|
||||||
|
self.sequence_number,
|
||||||
|
len(output_messages),
|
||||||
|
matching_item_id,
|
||||||
|
self.mcp_tool_to_server,
|
||||||
|
):
|
||||||
|
if result.stream_event:
|
||||||
|
# Forward streaming events
|
||||||
|
self.sequence_number = result.sequence_number
|
||||||
|
yield result.stream_event
|
||||||
|
|
||||||
|
if result.final_output_message is not None:
|
||||||
|
tool_call_log = result.final_output_message
|
||||||
|
tool_response_message = result.final_input_message
|
||||||
|
self.sequence_number = result.sequence_number
|
||||||
|
|
||||||
|
if tool_call_log:
|
||||||
|
output_messages.append(tool_call_log)
|
||||||
|
|
||||||
|
# Emit output_item.done event for completed non-function tool call
|
||||||
|
if matching_item_id:
|
||||||
|
self.sequence_number += 1
|
||||||
|
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||||
|
response_id=self.response_id,
|
||||||
|
item=tool_call_log,
|
||||||
|
output_index=len(output_messages) - 1,
|
||||||
|
sequence_number=self.sequence_number,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tool_response_message:
|
||||||
|
next_turn_messages.append(tool_response_message)
|
||||||
|
|
||||||
|
# Execute function tool calls (client-side)
|
||||||
|
for tool_call in function_tool_calls:
|
||||||
|
# Find the item_id for this tool call from our tracking dictionary
|
||||||
|
matching_item_id = None
|
||||||
|
for index, item_id in completion_result_data.tool_call_item_ids.items():
|
||||||
|
response_tool_call = completion_result_data.tool_calls.get(index)
|
||||||
|
if response_tool_call and response_tool_call.id == tool_call.id:
|
||||||
|
matching_item_id = item_id
|
||||||
|
break
|
||||||
|
|
||||||
|
# Use existing item_id or create new one if not found
|
||||||
|
final_item_id = matching_item_id or f"fc_{uuid.uuid4()}"
|
||||||
|
|
||||||
|
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
|
||||||
|
arguments=tool_call.function.arguments or "",
|
||||||
|
call_id=tool_call.id,
|
||||||
|
name=tool_call.function.name or "",
|
||||||
|
id=final_item_id,
|
||||||
|
status="completed",
|
||||||
|
)
|
||||||
|
output_messages.append(function_call_item)
|
||||||
|
|
||||||
|
# Emit output_item.done event for completed function call
|
||||||
|
self.sequence_number += 1
|
||||||
|
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||||
|
response_id=self.response_id,
|
||||||
|
item=function_call_item,
|
||||||
|
output_index=len(output_messages) - 1,
|
||||||
|
sequence_number=self.sequence_number,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _process_tools(
|
||||||
|
self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
|
||||||
|
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||||
|
"""Process all tools and emit appropriate streaming events."""
|
||||||
|
from openai.types.chat import ChatCompletionToolParam
|
||||||
|
|
||||||
|
from llama_stack.apis.tools import Tool
|
||||||
|
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||||
|
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||||
|
|
||||||
|
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
|
||||||
|
tool_def = ToolDefinition(
|
||||||
|
tool_name=tool_name,
|
||||||
|
description=tool.description,
|
||||||
|
parameters={
|
||||||
|
param.name: ToolParamDefinition(
|
||||||
|
param_type=param.parameter_type,
|
||||||
|
description=param.description,
|
||||||
|
required=param.required,
|
||||||
|
default=param.default,
|
||||||
|
)
|
||||||
|
for param in tool.parameters
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return convert_tooldef_to_openai_tool(tool_def)
|
||||||
|
|
||||||
|
# Initialize chat_tools if not already set
|
||||||
|
if self.ctx.chat_tools is None:
|
||||||
|
self.ctx.chat_tools = []
|
||||||
|
|
||||||
|
for input_tool in tools:
|
||||||
|
if input_tool.type == "function":
|
||||||
|
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
|
||||||
|
elif input_tool.type in WebSearchToolTypes:
|
||||||
|
tool_name = "web_search"
|
||||||
|
# Need to access tool_groups_api from tool_executor
|
||||||
|
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
|
||||||
|
if not tool:
|
||||||
|
raise ValueError(f"Tool {tool_name} not found")
|
||||||
|
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
|
||||||
|
elif input_tool.type == "file_search":
|
||||||
|
tool_name = "knowledge_search"
|
||||||
|
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
|
||||||
|
if not tool:
|
||||||
|
raise ValueError(f"Tool {tool_name} not found")
|
||||||
|
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
|
||||||
|
elif input_tool.type == "mcp":
|
||||||
|
async for stream_event in self._process_mcp_tool(input_tool, output_messages):
|
||||||
|
yield stream_event
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
|
||||||
|
|
||||||
|
async def _process_mcp_tool(
|
||||||
|
self, mcp_tool: OpenAIResponseInputToolMCP, output_messages: list[OpenAIResponseOutput]
|
||||||
|
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||||
|
"""Process an MCP tool configuration and emit appropriate streaming events."""
|
||||||
|
from llama_stack.providers.utils.tools.mcp import list_mcp_tools
|
||||||
|
|
||||||
|
# Emit mcp_list_tools.in_progress
|
||||||
|
self.sequence_number += 1
|
||||||
|
yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
|
||||||
|
sequence_number=self.sequence_number,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Parse allowed/never allowed tools
|
||||||
|
always_allowed = None
|
||||||
|
never_allowed = None
|
||||||
|
if mcp_tool.allowed_tools:
|
||||||
|
if isinstance(mcp_tool.allowed_tools, list):
|
||||||
|
always_allowed = mcp_tool.allowed_tools
|
||||||
|
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
|
||||||
|
always_allowed = mcp_tool.allowed_tools.always
|
||||||
|
never_allowed = mcp_tool.allowed_tools.never
|
||||||
|
|
||||||
|
# Call list_mcp_tools
|
||||||
|
tool_defs = await list_mcp_tools(
|
||||||
|
endpoint=mcp_tool.server_url,
|
||||||
|
headers=mcp_tool.headers or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the MCP list tools message
|
||||||
|
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||||
|
id=f"mcp_list_{uuid.uuid4()}",
|
||||||
|
server_label=mcp_tool.server_label,
|
||||||
|
tools=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process tools and update context
|
||||||
|
for t in tool_defs.data:
|
||||||
|
if never_allowed and t.name in never_allowed:
|
||||||
|
continue
|
||||||
|
if not always_allowed or t.name in always_allowed:
|
||||||
|
# Add to chat tools for inference
|
||||||
|
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||||
|
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||||
|
|
||||||
|
tool_def = ToolDefinition(
|
||||||
|
tool_name=t.name,
|
||||||
|
description=t.description,
|
||||||
|
parameters={
|
||||||
|
param.name: ToolParamDefinition(
|
||||||
|
param_type=param.parameter_type,
|
||||||
|
description=param.description,
|
||||||
|
required=param.required,
|
||||||
|
default=param.default,
|
||||||
|
)
|
||||||
|
for param in t.parameters
|
||||||
|
},
|
||||||
|
)
|
||||||
|
openai_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||||
|
if self.ctx.chat_tools is None:
|
||||||
|
self.ctx.chat_tools = []
|
||||||
|
self.ctx.chat_tools.append(openai_tool)
|
||||||
|
|
||||||
|
# Add to MCP tool mapping
|
||||||
|
if t.name in self.mcp_tool_to_server:
|
||||||
|
raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}")
|
||||||
|
self.mcp_tool_to_server[t.name] = mcp_tool
|
||||||
|
|
||||||
|
# Add to MCP list message
|
||||||
|
mcp_list_message.tools.append(
|
||||||
|
MCPListToolsTool(
|
||||||
|
name=t.name,
|
||||||
|
description=t.description,
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
p.name: {
|
||||||
|
"type": p.parameter_type,
|
||||||
|
"description": p.description,
|
||||||
|
}
|
||||||
|
for p in t.parameters
|
||||||
|
},
|
||||||
|
"required": [p.name for p in t.parameters if p.required],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the MCP list message to output
|
||||||
|
output_messages.append(mcp_list_message)
|
||||||
|
|
||||||
|
# Emit output_item.added for the MCP list tools message
|
||||||
|
self.sequence_number += 1
|
||||||
|
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||||
|
response_id=self.response_id,
|
||||||
|
item=mcp_list_message,
|
||||||
|
output_index=len(output_messages) - 1,
|
||||||
|
sequence_number=self.sequence_number,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit mcp_list_tools.completed
|
||||||
|
self.sequence_number += 1
|
||||||
|
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
|
||||||
|
sequence_number=self.sequence_number,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit output_item.done for the MCP list tools message
|
||||||
|
self.sequence_number += 1
|
||||||
|
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||||
|
response_id=self.response_id,
|
||||||
|
item=mcp_list_message,
|
||||||
|
output_index=len(output_messages) - 1,
|
||||||
|
sequence_number=self.sequence_number,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# TODO: Emit mcp_list_tools.failed event if needed
|
||||||
|
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
|
||||||
|
raise
|
||||||
|
|
@ -0,0 +1,379 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
|
from llama_stack.apis.agents.openai_responses import (
|
||||||
|
OpenAIResponseInputToolFileSearch,
|
||||||
|
OpenAIResponseInputToolMCP,
|
||||||
|
OpenAIResponseObjectStreamResponseMcpCallCompleted,
|
||||||
|
OpenAIResponseObjectStreamResponseMcpCallFailed,
|
||||||
|
OpenAIResponseObjectStreamResponseMcpCallInProgress,
|
||||||
|
OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
|
||||||
|
OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
|
||||||
|
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
|
||||||
|
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||||
|
OpenAIResponseOutputMessageFileSearchToolCallResults,
|
||||||
|
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.common.content_types import (
|
||||||
|
ImageContentItem,
|
||||||
|
TextContentItem,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.inference import (
|
||||||
|
OpenAIChatCompletionContentPartImageParam,
|
||||||
|
OpenAIChatCompletionContentPartTextParam,
|
||||||
|
OpenAIChatCompletionToolCall,
|
||||||
|
OpenAIImageURL,
|
||||||
|
OpenAIToolMessageParam,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||||
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
from .types import ChatCompletionContext, ToolExecutionResult
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="responses")
|
||||||
|
|
||||||
|
|
||||||
|
class ToolExecutor:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tool_groups_api: ToolGroups,
|
||||||
|
tool_runtime_api: ToolRuntime,
|
||||||
|
vector_io_api: VectorIO,
|
||||||
|
):
|
||||||
|
self.tool_groups_api = tool_groups_api
|
||||||
|
self.tool_runtime_api = tool_runtime_api
|
||||||
|
self.vector_io_api = vector_io_api
|
||||||
|
|
||||||
|
async def execute_tool_call(
|
||||||
|
self,
|
||||||
|
tool_call: OpenAIChatCompletionToolCall,
|
||||||
|
ctx: ChatCompletionContext,
|
||||||
|
sequence_number: int,
|
||||||
|
output_index: int,
|
||||||
|
item_id: str,
|
||||||
|
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||||
|
) -> AsyncIterator[ToolExecutionResult]:
|
||||||
|
tool_call_id = tool_call.id
|
||||||
|
function = tool_call.function
|
||||||
|
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
|
||||||
|
|
||||||
|
if not function or not tool_call_id or not function.name:
|
||||||
|
yield ToolExecutionResult(sequence_number=sequence_number)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Emit progress events for tool execution start
|
||||||
|
async for event_result in self._emit_progress_events(
|
||||||
|
function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server
|
||||||
|
):
|
||||||
|
sequence_number = event_result.sequence_number
|
||||||
|
yield event_result
|
||||||
|
|
||||||
|
# Execute the actual tool call
|
||||||
|
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
|
||||||
|
|
||||||
|
# Emit completion events for tool execution
|
||||||
|
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
|
||||||
|
async for event_result in self._emit_completion_events(
|
||||||
|
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
|
||||||
|
):
|
||||||
|
sequence_number = event_result.sequence_number
|
||||||
|
yield event_result
|
||||||
|
|
||||||
|
# Build result messages from tool execution
|
||||||
|
output_message, input_message = await self._build_result_messages(
|
||||||
|
function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
|
||||||
|
)
|
||||||
|
|
||||||
|
# Yield the final result
|
||||||
|
yield ToolExecutionResult(
|
||||||
|
sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _execute_knowledge_search_via_vector_store(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
response_file_search_tool: OpenAIResponseInputToolFileSearch,
|
||||||
|
) -> ToolInvocationResult:
|
||||||
|
"""Execute knowledge search using vector_stores.search API with filters support."""
|
||||||
|
search_results = []
|
||||||
|
|
||||||
|
# Create search tasks for all vector stores
|
||||||
|
async def search_single_store(vector_store_id):
|
||||||
|
try:
|
||||||
|
search_response = await self.vector_io_api.openai_search_vector_store(
|
||||||
|
vector_store_id=vector_store_id,
|
||||||
|
query=query,
|
||||||
|
filters=response_file_search_tool.filters,
|
||||||
|
max_num_results=response_file_search_tool.max_num_results,
|
||||||
|
ranking_options=response_file_search_tool.ranking_options,
|
||||||
|
rewrite_query=False,
|
||||||
|
)
|
||||||
|
return search_response.data
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Run all searches in parallel using gather
|
||||||
|
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
|
||||||
|
all_results = await asyncio.gather(*search_tasks)
|
||||||
|
|
||||||
|
# Flatten results
|
||||||
|
for results in all_results:
|
||||||
|
search_results.extend(results)
|
||||||
|
|
||||||
|
# Convert search results to tool result format matching memory.py
|
||||||
|
# Format the results as interleaved content similar to memory.py
|
||||||
|
content_items = []
|
||||||
|
content_items.append(
|
||||||
|
TextContentItem(
|
||||||
|
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, result_item in enumerate(search_results):
|
||||||
|
chunk_text = result_item.content[0].text if result_item.content else ""
|
||||||
|
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
|
||||||
|
if result_item.attributes:
|
||||||
|
metadata_text += f", attributes: {result_item.attributes}"
|
||||||
|
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
|
||||||
|
content_items.append(TextContentItem(text=text_content))
|
||||||
|
|
||||||
|
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||||
|
content_items.append(
|
||||||
|
TextContentItem(
|
||||||
|
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ToolInvocationResult(
|
||||||
|
content=content_items,
|
||||||
|
metadata={
|
||||||
|
"document_ids": [r.file_id for r in search_results],
|
||||||
|
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
||||||
|
"scores": [r.score for r in search_results],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _emit_progress_events(
|
||||||
|
self,
|
||||||
|
function_name: str,
|
||||||
|
ctx: ChatCompletionContext,
|
||||||
|
sequence_number: int,
|
||||||
|
output_index: int,
|
||||||
|
item_id: str,
|
||||||
|
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||||
|
) -> AsyncIterator[ToolExecutionResult]:
|
||||||
|
"""Emit progress events for tool execution start."""
|
||||||
|
# Emit in_progress event based on tool type (only for tools with specific streaming events)
|
||||||
|
progress_event = None
|
||||||
|
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||||
|
sequence_number += 1
|
||||||
|
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||||
|
item_id=item_id,
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=sequence_number,
|
||||||
|
)
|
||||||
|
elif function_name == "web_search":
|
||||||
|
sequence_number += 1
|
||||||
|
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
||||||
|
item_id=item_id,
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=sequence_number,
|
||||||
|
)
|
||||||
|
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
|
||||||
|
|
||||||
|
if progress_event:
|
||||||
|
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
|
||||||
|
|
||||||
|
# For web search, emit searching event
|
||||||
|
if function_name == "web_search":
|
||||||
|
sequence_number += 1
|
||||||
|
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
||||||
|
item_id=item_id,
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=sequence_number,
|
||||||
|
)
|
||||||
|
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||||
|
|
||||||
|
async def _execute_tool(
|
||||||
|
self,
|
||||||
|
function_name: str,
|
||||||
|
tool_kwargs: dict,
|
||||||
|
ctx: ChatCompletionContext,
|
||||||
|
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||||
|
) -> tuple[Exception | None, any]:
|
||||||
|
"""Execute the tool and return error exception and result."""
|
||||||
|
error_exc = None
|
||||||
|
result = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||||
|
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
|
||||||
|
|
||||||
|
mcp_tool = mcp_tool_to_server[function_name]
|
||||||
|
result = await invoke_mcp_tool(
|
||||||
|
endpoint=mcp_tool.server_url,
|
||||||
|
headers=mcp_tool.headers or {},
|
||||||
|
tool_name=function_name,
|
||||||
|
kwargs=tool_kwargs,
|
||||||
|
)
|
||||||
|
elif function_name == "knowledge_search":
|
||||||
|
response_file_search_tool = next(
|
||||||
|
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if response_file_search_tool:
|
||||||
|
# Use vector_stores.search API instead of knowledge_search tool
|
||||||
|
# to support filters and ranking_options
|
||||||
|
query = tool_kwargs.get("query", "")
|
||||||
|
result = await self._execute_knowledge_search_via_vector_store(
|
||||||
|
query=query,
|
||||||
|
response_file_search_tool=response_file_search_tool,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = await self.tool_runtime_api.invoke_tool(
|
||||||
|
tool_name=function_name,
|
||||||
|
kwargs=tool_kwargs,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
error_exc = e
|
||||||
|
|
||||||
|
return error_exc, result
|
||||||
|
|
||||||
|
async def _emit_completion_events(
|
||||||
|
self,
|
||||||
|
function_name: str,
|
||||||
|
ctx: ChatCompletionContext,
|
||||||
|
sequence_number: int,
|
||||||
|
output_index: int,
|
||||||
|
item_id: str,
|
||||||
|
has_error: bool,
|
||||||
|
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||||
|
) -> AsyncIterator[ToolExecutionResult]:
|
||||||
|
"""Emit completion or failure events for tool execution."""
|
||||||
|
completion_event = None
|
||||||
|
|
||||||
|
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||||
|
sequence_number += 1
|
||||||
|
if has_error:
|
||||||
|
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||||
|
sequence_number=sequence_number,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||||
|
sequence_number=sequence_number,
|
||||||
|
)
|
||||||
|
elif function_name == "web_search":
|
||||||
|
sequence_number += 1
|
||||||
|
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||||
|
item_id=item_id,
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=sequence_number,
|
||||||
|
)
|
||||||
|
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
|
||||||
|
|
||||||
|
if completion_event:
|
||||||
|
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
|
||||||
|
|
||||||
|
async def _build_result_messages(
|
||||||
|
self,
|
||||||
|
function,
|
||||||
|
tool_call_id: str,
|
||||||
|
tool_kwargs: dict,
|
||||||
|
ctx: ChatCompletionContext,
|
||||||
|
error_exc: Exception | None,
|
||||||
|
result: any,
|
||||||
|
has_error: bool,
|
||||||
|
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||||
|
) -> tuple[any, any]:
|
||||||
|
"""Build output and input messages from tool execution results."""
|
||||||
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
interleaved_content_as_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build output message
|
||||||
|
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
||||||
|
from llama_stack.apis.agents.openai_responses import (
|
||||||
|
OpenAIResponseOutputMessageMCPCall,
|
||||||
|
)
|
||||||
|
|
||||||
|
message = OpenAIResponseOutputMessageMCPCall(
|
||||||
|
id=tool_call_id,
|
||||||
|
arguments=function.arguments,
|
||||||
|
name=function.name,
|
||||||
|
server_label=mcp_tool_to_server[function.name].server_label,
|
||||||
|
)
|
||||||
|
if error_exc:
|
||||||
|
message.error = str(error_exc)
|
||||||
|
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
|
||||||
|
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
||||||
|
elif result and result.content:
|
||||||
|
message.output = interleaved_content_as_str(result.content)
|
||||||
|
else:
|
||||||
|
if function.name == "web_search":
|
||||||
|
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||||
|
id=tool_call_id,
|
||||||
|
status="completed",
|
||||||
|
)
|
||||||
|
if has_error:
|
||||||
|
message.status = "failed"
|
||||||
|
elif function.name == "knowledge_search":
|
||||||
|
message = OpenAIResponseOutputMessageFileSearchToolCall(
|
||||||
|
id=tool_call_id,
|
||||||
|
queries=[tool_kwargs.get("query", "")],
|
||||||
|
status="completed",
|
||||||
|
)
|
||||||
|
if result and "document_ids" in result.metadata:
|
||||||
|
message.results = []
|
||||||
|
for i, doc_id in enumerate(result.metadata["document_ids"]):
|
||||||
|
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
|
||||||
|
score = result.metadata["scores"][i] if "scores" in result.metadata else None
|
||||||
|
message.results.append(
|
||||||
|
OpenAIResponseOutputMessageFileSearchToolCallResults(
|
||||||
|
file_id=doc_id,
|
||||||
|
filename=doc_id,
|
||||||
|
text=text,
|
||||||
|
score=score,
|
||||||
|
attributes={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if has_error:
|
||||||
|
message.status = "failed"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown tool {function.name} called")
|
||||||
|
|
||||||
|
# Build input message
|
||||||
|
input_message = None
|
||||||
|
if result and result.content:
|
||||||
|
if isinstance(result.content, str):
|
||||||
|
content = result.content
|
||||||
|
elif isinstance(result.content, list):
|
||||||
|
content = []
|
||||||
|
for item in result.content:
|
||||||
|
if isinstance(item, TextContentItem):
|
||||||
|
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
||||||
|
elif isinstance(item, ImageContentItem):
|
||||||
|
if item.image.data:
|
||||||
|
url = f"data:image;base64,{item.image.data}"
|
||||||
|
else:
|
||||||
|
url = item.image.url
|
||||||
|
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown result content type: {type(item)}")
|
||||||
|
content.append(part)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
||||||
|
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
||||||
|
else:
|
||||||
|
text = str(error_exc) if error_exc else "Tool execution failed"
|
||||||
|
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
||||||
|
|
||||||
|
return message, input_message
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from openai.types.chat import ChatCompletionToolParam
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.apis.agents.openai_responses import (
|
||||||
|
OpenAIResponseInputTool,
|
||||||
|
OpenAIResponseObjectStream,
|
||||||
|
OpenAIResponseOutput,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
|
||||||
|
|
||||||
|
|
||||||
|
class ToolExecutionResult(BaseModel):
|
||||||
|
"""Result of streaming tool execution."""
|
||||||
|
|
||||||
|
stream_event: OpenAIResponseObjectStream | None = None
|
||||||
|
sequence_number: int
|
||||||
|
final_output_message: OpenAIResponseOutput | None = None
|
||||||
|
final_input_message: OpenAIMessageParam | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatCompletionResult:
|
||||||
|
"""Result of processing streaming chat completion chunks."""
|
||||||
|
|
||||||
|
response_id: str
|
||||||
|
content: list[str]
|
||||||
|
tool_calls: dict[int, OpenAIChatCompletionToolCall]
|
||||||
|
created: int
|
||||||
|
model: str
|
||||||
|
finish_reason: str
|
||||||
|
message_item_id: str # For streaming events
|
||||||
|
tool_call_item_ids: dict[int, str] # For streaming events
|
||||||
|
content_part_emitted: bool # Tracking state
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content_text(self) -> str:
|
||||||
|
"""Get joined content as string."""
|
||||||
|
return "".join(self.content)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_tool_calls(self) -> bool:
|
||||||
|
"""Check if there are any tool calls."""
|
||||||
|
return bool(self.tool_calls)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionContext(BaseModel):
|
||||||
|
model: str
|
||||||
|
messages: list[OpenAIMessageParam]
|
||||||
|
response_tools: list[OpenAIResponseInputTool] | None = None
|
||||||
|
chat_tools: list[ChatCompletionToolParam] | None = None
|
||||||
|
temperature: float | None
|
||||||
|
response_format: OpenAIResponseFormatParam
|
||||||
|
|
@ -0,0 +1,169 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from llama_stack.apis.agents.openai_responses import (
|
||||||
|
OpenAIResponseInput,
|
||||||
|
OpenAIResponseInputFunctionToolCallOutput,
|
||||||
|
OpenAIResponseInputMessageContent,
|
||||||
|
OpenAIResponseInputMessageContentImage,
|
||||||
|
OpenAIResponseInputMessageContentText,
|
||||||
|
OpenAIResponseInputTool,
|
||||||
|
OpenAIResponseMessage,
|
||||||
|
OpenAIResponseOutputMessageContent,
|
||||||
|
OpenAIResponseOutputMessageContentOutputText,
|
||||||
|
OpenAIResponseOutputMessageFunctionToolCall,
|
||||||
|
OpenAIResponseText,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.inference import (
|
||||||
|
OpenAIAssistantMessageParam,
|
||||||
|
OpenAIChatCompletionContentPartImageParam,
|
||||||
|
OpenAIChatCompletionContentPartParam,
|
||||||
|
OpenAIChatCompletionContentPartTextParam,
|
||||||
|
OpenAIChatCompletionToolCall,
|
||||||
|
OpenAIChatCompletionToolCallFunction,
|
||||||
|
OpenAIChoice,
|
||||||
|
OpenAIDeveloperMessageParam,
|
||||||
|
OpenAIImageURL,
|
||||||
|
OpenAIJSONSchema,
|
||||||
|
OpenAIMessageParam,
|
||||||
|
OpenAIResponseFormatJSONObject,
|
||||||
|
OpenAIResponseFormatJSONSchema,
|
||||||
|
OpenAIResponseFormatParam,
|
||||||
|
OpenAIResponseFormatText,
|
||||||
|
OpenAISystemMessageParam,
|
||||||
|
OpenAIToolMessageParam,
|
||||||
|
OpenAIUserMessageParam,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
|
||||||
|
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
|
||||||
|
output_content = ""
|
||||||
|
if isinstance(choice.message.content, str):
|
||||||
|
output_content = choice.message.content
|
||||||
|
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
|
||||||
|
output_content = choice.message.content.text
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return OpenAIResponseMessage(
|
||||||
|
id=f"msg_{uuid.uuid4()}",
|
||||||
|
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
|
||||||
|
status="completed",
|
||||||
|
role="assistant",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def convert_response_content_to_chat_content(
|
||||||
|
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
|
||||||
|
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||||
|
"""
|
||||||
|
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
||||||
|
|
||||||
|
The content schemas of each API look similar, but are not exactly the same.
|
||||||
|
"""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
|
||||||
|
converted_parts = []
|
||||||
|
for content_part in content:
|
||||||
|
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||||
|
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||||
|
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
|
||||||
|
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||||
|
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
|
||||||
|
if content_part.image_url:
|
||||||
|
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
|
||||||
|
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||||
|
elif isinstance(content_part, str):
|
||||||
|
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
|
||||||
|
)
|
||||||
|
return converted_parts
|
||||||
|
|
||||||
|
|
||||||
|
async def convert_response_input_to_chat_messages(
|
||||||
|
input: str | list[OpenAIResponseInput],
|
||||||
|
) -> list[OpenAIMessageParam]:
|
||||||
|
"""
|
||||||
|
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
|
||||||
|
"""
|
||||||
|
messages: list[OpenAIMessageParam] = []
|
||||||
|
if isinstance(input, list):
|
||||||
|
for input_item in input:
|
||||||
|
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||||
|
messages.append(
|
||||||
|
OpenAIToolMessageParam(
|
||||||
|
content=input_item.output,
|
||||||
|
tool_call_id=input_item.call_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
|
||||||
|
tool_call = OpenAIChatCompletionToolCall(
|
||||||
|
index=0,
|
||||||
|
id=input_item.call_id,
|
||||||
|
function=OpenAIChatCompletionToolCallFunction(
|
||||||
|
name=input_item.name,
|
||||||
|
arguments=input_item.arguments,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||||
|
else:
|
||||||
|
content = await convert_response_content_to_chat_content(input_item.content)
|
||||||
|
message_type = await get_message_type_by_role(input_item.role)
|
||||||
|
if message_type is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||||
|
)
|
||||||
|
messages.append(message_type(content=content))
|
||||||
|
else:
|
||||||
|
messages.append(OpenAIUserMessageParam(content=input))
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
async def convert_response_text_to_chat_response_format(
|
||||||
|
text: OpenAIResponseText,
|
||||||
|
) -> OpenAIResponseFormatParam:
|
||||||
|
"""
|
||||||
|
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
|
||||||
|
"""
|
||||||
|
if not text.format or text.format["type"] == "text":
|
||||||
|
return OpenAIResponseFormatText(type="text")
|
||||||
|
if text.format["type"] == "json_object":
|
||||||
|
return OpenAIResponseFormatJSONObject()
|
||||||
|
if text.format["type"] == "json_schema":
|
||||||
|
return OpenAIResponseFormatJSONSchema(
|
||||||
|
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
|
||||||
|
)
|
||||||
|
raise ValueError(f"Unsupported text format: {text.format}")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_message_type_by_role(role: str):
|
||||||
|
role_to_type = {
|
||||||
|
"user": OpenAIUserMessageParam,
|
||||||
|
"system": OpenAISystemMessageParam,
|
||||||
|
"assistant": OpenAIAssistantMessageParam,
|
||||||
|
"developer": OpenAIDeveloperMessageParam,
|
||||||
|
}
|
||||||
|
return role_to_type.get(role)
|
||||||
|
|
||||||
|
|
||||||
|
def is_function_tool_call(
|
||||||
|
tool_call: OpenAIChatCompletionToolCall,
|
||||||
|
tools: list[OpenAIResponseInputTool],
|
||||||
|
) -> bool:
|
||||||
|
if not tool_call.function:
|
||||||
|
return False
|
||||||
|
for t in tools:
|
||||||
|
if t.type == "function" and t.name == tool_call.function.name:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
5
llama_stack/providers/inline/batches/__init__.py
Normal file
5
llama_stack/providers/inline/batches/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
36
llama_stack/providers/inline/batches/reference/__init__.py
Normal file
36
llama_stack/providers/inline/batches/reference/__init__.py
Normal file
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.apis.files import Files
|
||||||
|
from llama_stack.apis.inference import Inference
|
||||||
|
from llama_stack.apis.models import Models
|
||||||
|
from llama_stack.core.datatypes import AccessRule, Api
|
||||||
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
|
|
||||||
|
from .batches import ReferenceBatchesImpl
|
||||||
|
from .config import ReferenceBatchesImplConfig
|
||||||
|
|
||||||
|
__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
|
||||||
|
kvstore = await kvstore_impl(config.kvstore)
|
||||||
|
inference_api: Inference | None = deps.get(Api.inference)
|
||||||
|
files_api: Files | None = deps.get(Api.files)
|
||||||
|
models_api: Models | None = deps.get(Api.models)
|
||||||
|
|
||||||
|
if inference_api is None:
|
||||||
|
raise ValueError("Inference API is required but not provided in dependencies")
|
||||||
|
if files_api is None:
|
||||||
|
raise ValueError("Files API is required but not provided in dependencies")
|
||||||
|
if models_api is None:
|
||||||
|
raise ValueError("Models API is required but not provided in dependencies")
|
||||||
|
|
||||||
|
impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
||||||
580
llama_stack/providers/inline/batches/reference/batches.py
Normal file
580
llama_stack/providers/inline/batches/reference/batches.py
Normal file
|
|
@ -0,0 +1,580 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import itertools
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from openai.types.batch import BatchError, Errors
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
|
||||||
|
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||||
|
from llama_stack.apis.files import Files, OpenAIFilePurpose
|
||||||
|
from llama_stack.apis.inference import (
|
||||||
|
Inference,
|
||||||
|
OpenAIAssistantMessageParam,
|
||||||
|
OpenAIDeveloperMessageParam,
|
||||||
|
OpenAIMessageParam,
|
||||||
|
OpenAISystemMessageParam,
|
||||||
|
OpenAIToolMessageParam,
|
||||||
|
OpenAIUserMessageParam,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.models import Models
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
|
from .config import ReferenceBatchesImplConfig
|
||||||
|
|
||||||
|
BATCH_PREFIX = "batch:"
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncBytesIO:
|
||||||
|
"""
|
||||||
|
Async-compatible BytesIO wrapper to allow async file-like operations.
|
||||||
|
|
||||||
|
We use this when uploading files to the Files API, as it expects an
|
||||||
|
async file-like object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data: bytes):
|
||||||
|
self._buffer = BytesIO(data)
|
||||||
|
|
||||||
|
async def read(self, n=-1):
|
||||||
|
return self._buffer.read(n)
|
||||||
|
|
||||||
|
async def seek(self, pos, whence=0):
|
||||||
|
return self._buffer.seek(pos, whence)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self._buffer.close()
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self._buffer, name)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchRequest(BaseModel):
|
||||||
|
line_num: int
|
||||||
|
custom_id: str
|
||||||
|
method: str
|
||||||
|
url: str
|
||||||
|
body: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam:
|
||||||
|
"""Convert a message dictionary to OpenAIMessageParam based on role."""
|
||||||
|
role = msg.get("role")
|
||||||
|
|
||||||
|
if role == "user":
|
||||||
|
return OpenAIUserMessageParam(**msg)
|
||||||
|
elif role == "system":
|
||||||
|
return OpenAISystemMessageParam(**msg)
|
||||||
|
elif role == "assistant":
|
||||||
|
return OpenAIAssistantMessageParam(**msg)
|
||||||
|
elif role == "tool":
|
||||||
|
return OpenAIToolMessageParam(**msg)
|
||||||
|
elif role == "developer":
|
||||||
|
return OpenAIDeveloperMessageParam(**msg)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown message role: {role}")
|
||||||
|
|
||||||
|
|
||||||
|
class ReferenceBatchesImpl(Batches):
|
||||||
|
"""Reference implementation of the Batches API.
|
||||||
|
|
||||||
|
This implementation processes batch files by making individual requests
|
||||||
|
to the inference API and generates output files with results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: ReferenceBatchesImplConfig,
|
||||||
|
inference_api: Inference,
|
||||||
|
files_api: Files,
|
||||||
|
models_api: Models,
|
||||||
|
kvstore: KVStore,
|
||||||
|
) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.kvstore = kvstore
|
||||||
|
self.inference_api = inference_api
|
||||||
|
self.files_api = files_api
|
||||||
|
self.models_api = models_api
|
||||||
|
self._processing_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches)
|
||||||
|
self._update_batch_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# this is to allow tests to disable background processing
|
||||||
|
self.process_batches = True
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
# TODO: start background processing of existing tasks
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
"""Shutdown the batches provider."""
|
||||||
|
if self._processing_tasks:
|
||||||
|
# don't cancel tasks - just let them stop naturally on shutdown
|
||||||
|
# cancelling would mark batches as "cancelled" in the database
|
||||||
|
logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
|
||||||
|
|
||||||
|
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
|
||||||
|
async def create_batch(
|
||||||
|
self,
|
||||||
|
input_file_id: str,
|
||||||
|
endpoint: str,
|
||||||
|
completion_window: Literal["24h"],
|
||||||
|
metadata: dict[str, str] | None = None,
|
||||||
|
) -> BatchObject:
|
||||||
|
"""
|
||||||
|
Create a new batch for processing multiple API requests.
|
||||||
|
|
||||||
|
Error handling by levels -
|
||||||
|
0. Input param handling, results in 40x errors before processing, e.g.
|
||||||
|
- Wrong completion_window
|
||||||
|
- Invalid metadata types
|
||||||
|
- Unknown endpoint
|
||||||
|
-> no batch created
|
||||||
|
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
|
||||||
|
- input_file_id missing
|
||||||
|
- invalid json in file
|
||||||
|
- missing custom_id, method, url, body
|
||||||
|
- invalid model
|
||||||
|
- streaming
|
||||||
|
-> batch created, validation sends to failed status
|
||||||
|
2. Processing errors, result in error_file_id entries, e.g.
|
||||||
|
- Any error returned from inference endpoint
|
||||||
|
-> batch created, goes to completed status
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: set expiration time for garbage collection
|
||||||
|
|
||||||
|
if endpoint not in ["/v1/chat/completions"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
|
||||||
|
)
|
||||||
|
|
||||||
|
if completion_window != "24h":
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
|
||||||
|
current_time = int(time.time())
|
||||||
|
|
||||||
|
batch = BatchObject(
|
||||||
|
id=batch_id,
|
||||||
|
object="batch",
|
||||||
|
endpoint=endpoint,
|
||||||
|
input_file_id=input_file_id,
|
||||||
|
completion_window=completion_window,
|
||||||
|
status="validating",
|
||||||
|
created_at=current_time,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
|
||||||
|
|
||||||
|
if self.process_batches:
|
||||||
|
task = asyncio.create_task(self._process_batch(batch_id))
|
||||||
|
self._processing_tasks[batch_id] = task
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||||
|
"""Cancel a batch that is in progress."""
|
||||||
|
batch = await self.retrieve_batch(batch_id)
|
||||||
|
|
||||||
|
if batch.status in ["cancelled", "cancelling"]:
|
||||||
|
return batch
|
||||||
|
|
||||||
|
if batch.status in ["completed", "failed", "expired"]:
|
||||||
|
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
|
||||||
|
|
||||||
|
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
|
||||||
|
|
||||||
|
if batch_id in self._processing_tasks:
|
||||||
|
self._processing_tasks[batch_id].cancel()
|
||||||
|
# note: task removal and status="cancelled" handled in finally block of _process_batch
|
||||||
|
|
||||||
|
return await self.retrieve_batch(batch_id)
|
||||||
|
|
||||||
|
async def list_batches(
|
||||||
|
self,
|
||||||
|
after: str | None = None,
|
||||||
|
limit: int = 20,
|
||||||
|
) -> ListBatchesResponse:
|
||||||
|
"""
|
||||||
|
        List all batches, eventually only for the current user.

        With no notion of user, we return all batches.
        """
        batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff")

        batches = []
        for batch_data in batch_values:
            if batch_data:
                batches.append(BatchObject.model_validate_json(batch_data))

        batches.sort(key=lambda b: b.created_at, reverse=True)

        start_idx = 0
        if after:
            for i, batch in enumerate(batches):
                if batch.id == after:
                    start_idx = i + 1
                    break

        page_batches = batches[start_idx : start_idx + limit]
        has_more = (start_idx + limit) < len(batches)

        first_id = page_batches[0].id if page_batches else None
        last_id = page_batches[-1].id if page_batches else None

        return ListBatchesResponse(
            data=page_batches,
            first_id=first_id,
            last_id=last_id,
            has_more=has_more,
        )

    async def retrieve_batch(self, batch_id: str) -> BatchObject:
        """Retrieve information about a specific batch."""
        batch_data = await self.kvstore.get(f"batch:{batch_id}")
        if not batch_data:
            raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")

        return BatchObject.model_validate_json(batch_data)

    async def _update_batch(self, batch_id: str, **updates) -> None:
        """Update batch fields in kvstore."""
        async with self._update_batch_lock:
            try:
                batch = await self.retrieve_batch(batch_id)

                # batch processing is async. once cancelling, only allow "cancelled" status updates
                if batch.status == "cancelling" and updates.get("status") != "cancelled":
                    logger.info(
                        f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}"
                    )
                    return

                if "errors" in updates:
                    updates["errors"] = updates["errors"].model_dump()

                batch_dict = batch.model_dump()
                batch_dict.update(updates)

                await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict))
            except Exception as e:
                logger.error(f"Failed to update batch {batch_id}: {e}")

    async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]:
        """
        Read & validate input, return errors and valid input.

        Validation of
         - input_file_id existance
         - valid json
         - custom_id, method, url, body presence and valid
         - no streaming
        """
        requests: list[BatchRequest] = []
        errors: list[BatchError] = []
        try:
            await self.files_api.openai_retrieve_file(batch.input_file_id)
        except Exception:
            errors.append(
                BatchError(
                    code="invalid_request",
                    line=None,
                    message=f"Cannot find file {batch.input_file_id}.",
                    param="input_file_id",
                )
            )
            return errors, requests

        # TODO(SECURITY): do something about large files
        file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
        file_content = file_content_response.body.decode("utf-8")
        for line_num, line in enumerate(file_content.strip().split("\n"), 1):
            if line.strip():  # skip empty lines
                try:
                    request = json.loads(line)

                    if not isinstance(request, dict):
                        errors.append(
                            BatchError(
                                code="invalid_request",
                                line=line_num,
                                message="Each line must be a JSON dictionary object",
                            )
                        )
                        continue

                    valid = True

                    for param, expected_type, type_string in [
                        ("custom_id", str, "string"),
                        ("method", str, "string"),
                        ("url", str, "string"),
                        ("body", dict, "JSON dictionary object"),
                    ]:
                        if param not in request:
                            errors.append(
                                BatchError(
                                    code="missing_required_parameter",
                                    line=line_num,
                                    message=f"Missing required parameter: {param}",
                                    param=param,
                                )
                            )
                            valid = False
                        elif not isinstance(request[param], expected_type):
                            param_name = "URL" if param == "url" else param.capitalize()
                            errors.append(
                                BatchError(
                                    code="invalid_request",
                                    line=line_num,
                                    message=f"{param_name} must be a {type_string}",
                                    param=param,
                                )
                            )
                            valid = False

                    if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint:
                        errors.append(
                            BatchError(
                                code="invalid_url",
                                line=line_num,
                                message="URL provided for this request does not match the batch endpoint",
                                param="url",
                            )
                        )
                        valid = False

                    if (body := request.get("body")) and isinstance(body, dict):
                        if body.get("stream", False):
                            errors.append(
                                BatchError(
                                    code="streaming_unsupported",
                                    line=line_num,
                                    message="Streaming is not supported in batch processing",
                                    param="body.stream",
                                )
                            )
                            valid = False

                        for param, expected_type, type_string in [
                            ("model", str, "a string"),
                            # messages is specific to /v1/chat/completions
                            # we could skip validating messages here and let inference fail. however,
                            # that would be a very expensive way to find out messages is wrong.
                            ("messages", list, "an array"),  # TODO: allow messages to be a string?
                        ]:
                            if param not in body:
                                errors.append(
                                    BatchError(
                                        code="invalid_request",
                                        line=line_num,
                                        message=f"{param.capitalize()} parameter is required",
                                        param=f"body.{param}",
                                    )
                                )
                                valid = False
                            elif not isinstance(body[param], expected_type):
                                errors.append(
                                    BatchError(
                                        code="invalid_request",
                                        line=line_num,
                                        message=f"{param.capitalize()} must be {type_string}",
                                        param=f"body.{param}",
                                    )
                                )
                                valid = False

                        if "model" in body and isinstance(body["model"], str):
                            try:
                                await self.models_api.get_model(body["model"])
                            except Exception:
                                errors.append(
                                    BatchError(
                                        code="model_not_found",
                                        line=line_num,
                                        message=f"Model '{body['model']}' does not exist or is not supported",
                                        param="body.model",
                                    )
                                )
                                valid = False

                    if valid:
                        assert isinstance(url, str), "URL must be a string"  # for mypy
                        assert isinstance(body, dict), "Body must be a dictionary"  # for mypy
                        requests.append(
                            BatchRequest(
                                line_num=line_num,
                                url=url,
                                method=request["method"],
                                custom_id=request["custom_id"],
                                body=body,
                            ),
                        )
                except json.JSONDecodeError:
                    errors.append(
                        BatchError(
                            code="invalid_json_line",
                            line=line_num,
                            message="This line is not parseable as valid JSON.",
                        )
                    )

        return errors, requests

    async def _process_batch(self, batch_id: str) -> None:
        """Background task to process a batch of requests."""
        try:
            logger.info(f"Starting batch processing for {batch_id}")
            async with self._batch_semaphore:  # semaphore to limit concurrency
                logger.info(f"Acquired semaphore for batch {batch_id}")
                await self._process_batch_impl(batch_id)
        except asyncio.CancelledError:
            logger.info(f"Batch processing cancelled for {batch_id}")
            await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time()))
        except Exception as e:
            logger.error(f"Batch processing failed for {batch_id}: {e}")
            await self._update_batch(
                batch_id,
                status="failed",
                failed_at=int(time.time()),
                errors=Errors(data=[BatchError(code="internal_error", message=str(e))]),
            )
        finally:
            self._processing_tasks.pop(batch_id, None)

    async def _process_batch_impl(self, batch_id: str) -> None:
        """Implementation of batch processing logic."""
        errors: list[BatchError] = []
        batch = await self.retrieve_batch(batch_id)

        errors, requests = await self._validate_input(batch)
        if errors:
            await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors))
            logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors")
            return

        logger.info(f"Processing {len(requests)} requests for batch {batch_id}")

        total_requests = len(requests)
        await self._update_batch(
            batch_id,
            status="in_progress",
            request_counts={"total": total_requests, "completed": 0, "failed": 0},
        )

        error_results = []
        success_results = []
        completed_count = 0
        failed_count = 0

        for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch):
            # we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled
            async with asyncio.TaskGroup() as tg:
                chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk]

                chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True)

            for result in chunk_results:
                if isinstance(result, dict) and result.get("error") is not None:  # error response from inference
                    failed_count += 1
                    error_results.append(result)
                elif isinstance(result, dict) and result.get("response") is not None:  # successful inference
                    completed_count += 1
                    success_results.append(result)
                else:  # unexpected result
                    failed_count += 1
                    errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}"))

            await self._update_batch(
                batch_id,
                request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count},
            )

        if errors:
            await self._update_batch(
                batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)
            )
            return

        try:
            output_file_id = await self._create_output_file(batch_id, success_results, "success")
            await self._update_batch(batch_id, output_file_id=output_file_id)

            error_file_id = await self._create_output_file(batch_id, error_results, "error")
            await self._update_batch(batch_id, error_file_id=error_file_id)

            await self._update_batch(batch_id, status="completed", completed_at=int(time.time()))

            logger.info(
                f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed"
            )
        except Exception as e:
            # note: errors is empty at this point, so we don't lose anything by ignoring it
            await self._update_batch(
                batch_id,
                status="failed",
                failed_at=int(time.time()),
                errors=Errors(data=[BatchError(code="output_failed", message=str(e))]),
            )

    async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict:
        """Process a single request from the batch."""
        request_id = f"batch_req_{batch_id}_{request.line_num}"

        try:
            # TODO(SECURITY): review body for security issues
            request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
            chat_response = await self.inference_api.openai_chat_completion(**request.body)

            # this is for mypy, we don't allow streaming so we'll get the right type
            assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
            return {
                "id": request_id,
                "custom_id": request.custom_id,
                "response": {
                    "status_code": 200,
                    "request_id": request_id,  # TODO: should this be different?
                    "body": chat_response.model_dump_json(),
                },
            }
        except Exception as e:
            logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
            return {
                "id": request_id,
                "custom_id": request.custom_id,
                "error": {"type": "request_failed", "message": str(e)},
            }

    async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str:
        """
        Create an output file with batch results.

        This function filters results based on the specified file_type
        and uploads the file to the Files API.
        """
        output_lines = [json.dumps(result) for result in results]

        with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer:
            file_buffer.filename = f"{batch_id}_{file_type}.jsonl"
            uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
            return uploaded_file.id
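For orientation, a minimal, hypothetical client-side sketch of driving this provider through the OpenAI-compatible surface: upload a JSONL input file, create a batch against /v1/chat/completions, poll until it reaches a terminal state, then read the output file. The base URL, API key, and model identifier are placeholders and are not taken from this diff.

# Hypothetical usage sketch; adjust base_url and model to your running distribution.
import json
import time

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# One JSONL line per request: custom_id, method, url, and body are required,
# and body must carry model + messages with streaming disabled (see _validate_input above).
line = {
    "custom_id": "req-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {"model": "my-model", "messages": [{"role": "user", "content": "Say hi"}]},
}
input_file = client.files.create(file=("input.jsonl", json.dumps(line).encode()), purpose="batch")

batch = client.batches.create(
    input_file_id=input_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# The provider processes the batch in the background (_process_batch); poll until terminal.
while batch.status not in ("completed", "failed", "cancelled", "expired"):
    time.sleep(5)
    batch = client.batches.retrieve(batch.id)

if batch.output_file_id:
    print(client.files.content(batch.output_file_id).text)  # one JSON result per line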
40  llama_stack/providers/inline/batches/reference/config.py  Normal file

@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel, Field

from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig


class ReferenceBatchesImplConfig(BaseModel):
    """Configuration for the Reference Batches implementation."""

    kvstore: KVStoreConfig = Field(
        description="Configuration for the key-value store backend.",
    )

    max_concurrent_batches: int = Field(
        default=1,
        description="Maximum number of concurrent batches to process simultaneously.",
        ge=1,
    )

    max_concurrent_requests_per_batch: int = Field(
        default=10,
        description="Maximum number of concurrent requests to process per batch.",
        ge=1,
    )

    # TODO: add a max requests per second rate limiter

    @classmethod
    def sample_run_config(cls, __distro_dir__: str) -> dict:
        return {
            "kvstore": SqliteKVStoreConfig.sample_run_config(
                __distro_dir__=__distro_dir__,
                db_name="batches.db",
            ),
        }
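As a quick illustration (not part of the diff), the sample config above can be materialized and validated directly; the distro directory below is only a placeholder.

# Sketch: build the dict that sample_run_config produces and validate it with pydantic.
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig

raw = ReferenceBatchesImplConfig.sample_run_config(__distro_dir__="~/.llama/distributions/starter")
config = ReferenceBatchesImplConfig.model_validate(raw)
print(config.max_concurrent_batches, config.max_concurrent_requests_per_batch)  # defaults: 1 10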
26  llama_stack/providers/registry/batches.py  Normal file

@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec


def available_providers() -> list[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.batches,
            provider_type="inline::reference",
            pip_packages=["openai"],
            module="llama_stack.providers.inline.batches.reference",
            config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
            api_dependencies=[
                Api.inference,
                Api.files,
                Api.models,
            ],
            description="Reference implementation of batches API with KVStore persistence.",
        ),
    ]
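A small sketch (also not part of the diff) of how this registry entry can be inspected at runtime; the printed output shown in the comment assumes the Api enum's string values match the API names, which is not confirmed by this diff.

# Sketch: list the registered batches provider spec and its API dependencies.
from llama_stack.providers.registry.batches import available_providers

for spec in available_providers():
    print(spec.provider_type, spec.module, [api.value for api in spec.api_dependencies])
# assumed output: inline::reference llama_stack.providers.inline.batches.reference ['inference', 'files', 'models']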
@@ -31,15 +31,15 @@ from openai.types.chat import (
 from openai.types.chat import (
     ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
 )
+from openai.types.chat import (
+    ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall,
+)
 from openai.types.chat import (
     ChatCompletionMessageParam as OpenAIChatCompletionMessage,
 )
 from openai.types.chat import (
     ChatCompletionMessageToolCall,
 )
-from openai.types.chat import (
-    ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
-)
 from openai.types.chat import (
     ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
 )

@@ -633,7 +633,7 @@ async def convert_message_to_openai_dict_new(
         )
     elif isinstance(message, CompletionMessage):
         tool_calls = [
-            OpenAIChatCompletionMessageToolCall(
+            OpenAIChatCompletionMessageFunctionToolCall(
                 id=tool.call_id,
                 function=OpenAIFunction(
                     name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),

@@ -903,7 +903,7 @@ def _convert_openai_request_response_format(


 def _convert_openai_tool_calls(
-    tool_calls: list[OpenAIChatCompletionMessageToolCall],
+    tool_calls: list[OpenAIChatCompletionMessageFunctionToolCall],
 ) -> list[ToolCall]:
     """
     Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
1  llama_stack/ui/.nvmrc  Normal file

@@ -0,0 +1 @@
22.5.1
@@ -1,3 +1,12 @@
 # Ignore artifacts:
 build
 coverage
+.next
+node_modules
+dist
+*.lock
+*.log
+
+# Generated files
+*.min.js
+*.min.css
@@ -1 +1,10 @@
-{}
+{
+  "semi": true,
+  "trailingComma": "es5",
+  "singleQuote": false,
+  "printWidth": 80,
+  "tabWidth": 2,
+  "useTabs": false,
+  "bracketSpacing": true,
+  "arrowParens": "avoid"
+}
@@ -47,7 +47,7 @@ async function proxyRequest(request: NextRequest, method: string) {
     const responseText = await response.text();

     console.log(
-      `Response from FastAPI: ${response.status} ${response.statusText}`,
+      `Response from FastAPI: ${response.status} ${response.statusText}`
     );

     // Create response with same status and headers

@@ -74,7 +74,7 @@ async function proxyRequest(request: NextRequest, method: string) {
         backend_url: BACKEND_URL,
         timestamp: new Date().toISOString(),
       },
-      { status: 500 },
+      { status: 500 }
     );
   }
 }
@@ -51,9 +51,9 @@ export default function SignInPage() {
             onClick={() => {
               console.log("Signing in with GitHub...");
               signIn("github", { callbackUrl: "/auth/signin" }).catch(
-                (error) => {
+                error => {
                   console.error("Sign in error:", error);
-                },
+                }
               );
             }}
             className="w-full"
@@ -29,14 +29,13 @@ export default function ChatPlaygroundPage() {

   const isModelsLoading = modelsLoading ?? true;

-
   useEffect(() => {
     const fetchModels = async () => {
       try {
         setModelsLoading(true);
         setModelsError(null);
         const modelList = await client.models.list();
-        const llmModels = modelList.filter(model => model.model_type === 'llm');
+        const llmModels = modelList.filter(model => model.model_type === "llm");
         setModels(llmModels);
         if (llmModels.length > 0) {
           setSelectedModel(llmModels[0].identifier);

@@ -53,103 +52,122 @@ export default function ChatPlaygroundPage() {
   }, [client]);

   const extractTextContent = (content: unknown): string => {
-    if (typeof content === 'string') {
+    if (typeof content === "string") {
       return content;
     }
     if (Array.isArray(content)) {
       return content
-        .filter(item => item && typeof item === 'object' && 'type' in item && item.type === 'text')
-        .map(item => (item && typeof item === 'object' && 'text' in item) ? String(item.text) : '')
-        .join('');
+        .filter(
+          item =>
+            item &&
+            typeof item === "object" &&
+            "type" in item &&
+            item.type === "text"
+        )
+        .map(item =>
+          item && typeof item === "object" && "text" in item
+            ? String(item.text)
+            : ""
+        )
+        .join("");
     }
-    if (content && typeof content === 'object' && 'type' in content && content.type === 'text' && 'text' in content) {
-      return String(content.text) || '';
+    if (
+      content &&
+      typeof content === "object" &&
+      "type" in content &&
+      content.type === "text" &&
+      "text" in content
+    ) {
+      return String(content.text) || "";
     }
-    return '';
+    return "";
   };

   const handleInputChange = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
     setInput(e.target.value);
   };

   const handleSubmit = async (event?: { preventDefault?: () => void }) => {
     event?.preventDefault?.();
     if (!input.trim()) return;

     // Add user message to chat
     const userMessage: Message = {
       id: Date.now().toString(),
       role: "user",
       content: input.trim(),
-      createdAt: new Date(),
-    };
-
-    setMessages(prev => [...prev, userMessage]);
-    setInput("");
-
-    // Use the helper function with the content
-    await handleSubmitWithContent(userMessage.content);
-  };
-
-  const handleSubmitWithContent = async (content: string) => {
-    setIsGenerating(true);
-    setError(null);
-
-    try {
-      const messageParams: CompletionCreateParams["messages"] = [
-        ...messages.map(msg => {
-          const msgContent = typeof msg.content === 'string' ? msg.content : extractTextContent(msg.content);
-          if (msg.role === "user") {
-            return { role: "user" as const, content: msgContent };
-          } else if (msg.role === "assistant") {
-            return { role: "assistant" as const, content: msgContent };
-          } else {
-            return { role: "system" as const, content: msgContent };
-          }
-        }),
-        { role: "user" as const, content }
-      ];
-
-      const response = await client.chat.completions.create({
-        model: selectedModel,
-        messages: messageParams,
-        stream: true,
-      });
-
-      const assistantMessage: Message = {
-        id: (Date.now() + 1).toString(),
-        role: "assistant",
-        content: "",
       createdAt: new Date(),
     };

-      setMessages(prev => [...prev, assistantMessage]);
-      let fullContent = "";
-      for await (const chunk of response) {
-        if (chunk.choices && chunk.choices[0]?.delta?.content) {
-          const deltaContent = chunk.choices[0].delta.content;
-          fullContent += deltaContent;
-
-          flushSync(() => {
-            setMessages(prev => {
-              const newMessages = [...prev];
-              const lastMessage = newMessages[newMessages.length - 1];
-              if (lastMessage.role === "assistant") {
-                lastMessage.content = fullContent;
-              }
-              return newMessages;
+    setMessages(prev => [...prev, userMessage]);
+    setInput("");
+
+    // Use the helper function with the content
+    await handleSubmitWithContent(userMessage.content);
+  };
+
+  const handleSubmitWithContent = async (content: string) => {
+    setIsGenerating(true);
+    setError(null);
+
+    try {
+      const messageParams: CompletionCreateParams["messages"] = [
+        ...messages.map(msg => {
+          const msgContent =
+            typeof msg.content === "string"
+              ? msg.content
+              : extractTextContent(msg.content);
+          if (msg.role === "user") {
+            return { role: "user" as const, content: msgContent };
+          } else if (msg.role === "assistant") {
+            return { role: "assistant" as const, content: msgContent };
+          } else {
+            return { role: "system" as const, content: msgContent };
+          }
+        }),
+        { role: "user" as const, content },
+      ];
+
+      const response = await client.chat.completions.create({
+        model: selectedModel,
+        messages: messageParams,
+        stream: true,
+      });
+
+      const assistantMessage: Message = {
+        id: (Date.now() + 1).toString(),
+        role: "assistant",
+        content: "",
+        createdAt: new Date(),
+      };
+
+      setMessages(prev => [...prev, assistantMessage]);
+      let fullContent = "";
+      for await (const chunk of response) {
+        if (chunk.choices && chunk.choices[0]?.delta?.content) {
+          const deltaContent = chunk.choices[0].delta.content;
+          fullContent += deltaContent;
+
+          flushSync(() => {
+            setMessages(prev => {
+              const newMessages = [...prev];
+              const lastMessage = newMessages[newMessages.length - 1];
+              if (lastMessage.role === "assistant") {
+                lastMessage.content = fullContent;
+              }
+              return newMessages;
             });
           });
         }
       }
     } catch (err) {
       console.error("Error sending message:", err);
       setError("Failed to send message. Please try again.");
       setMessages(prev => prev.slice(0, -1));
     } finally {
       setIsGenerating(false);
     }
   };

   const suggestions = [
     "Write a Python function that prints 'Hello, World!'",
     "Explain step-by-step how to solve this math problem: If x² + 6x + 9 = 25, what is x?",

@@ -163,7 +181,7 @@ const handleSubmitWithContent = async (content: string) => {
       content: message.content,
       createdAt: new Date(),
     };
-    setMessages(prev => [...prev, newMessage])
+    setMessages(prev => [...prev, newMessage]);
     handleSubmitWithContent(newMessage.content);
   };

@@ -177,12 +195,20 @@ const handleSubmitWithContent = async (content: string) => {
       <div className="mb-4 flex justify-between items-center">
         <h1 className="text-2xl font-bold">Chat Playground (Completions)</h1>
         <div className="flex gap-2">
-          <Select value={selectedModel} onValueChange={setSelectedModel} disabled={isModelsLoading || isGenerating}>
+          <Select
+            value={selectedModel}
+            onValueChange={setSelectedModel}
+            disabled={isModelsLoading || isGenerating}
+          >
             <SelectTrigger className="w-[180px]">
-              <SelectValue placeholder={isModelsLoading ? "Loading models..." : "Select Model"} />
+              <SelectValue
+                placeholder={
+                  isModelsLoading ? "Loading models..." : "Select Model"
+                }
+              />
             </SelectTrigger>
             <SelectContent>
-              {models.map((model) => (
+              {models.map(model => (
                 <SelectItem key={model.identifier} value={model.identifier}>
                   {model.identifier}
                 </SelectItem>
@@ -33,12 +33,12 @@ export default function ChatCompletionDetailPage() {
       } catch (err) {
         console.error(
           `Error fetching chat completion detail for ID ${id}:`,
-          err,
+          err
         );
         setError(
           err instanceof Error
             ? err
-            : new Error("Failed to fetch completion detail"),
+            : new Error("Failed to fetch completion detail")
         );
       } finally {
         setIsLoading(false);
@@ -13,10 +13,10 @@ export default function ResponseDetailPage() {
   const client = useAuthClient();

   const [responseDetail, setResponseDetail] = useState<OpenAIResponse | null>(
-    null,
+    null
   );
   const [inputItems, setInputItems] = useState<InputItemListResponse | null>(
-    null,
+    null
   );
   const [isLoading, setIsLoading] = useState<boolean>(true);
   const [isLoadingInputItems, setIsLoadingInputItems] = useState<boolean>(true);

@@ -25,7 +25,7 @@ export default function ResponseDetailPage() {

   // Helper function to convert ResponseObject to OpenAIResponse
   const convertResponseObject = (
-    responseData: ResponseObject,
+    responseData: ResponseObject
   ): OpenAIResponse => {
     return {
       id: responseData.id,

@@ -73,12 +73,12 @@ export default function ResponseDetailPage() {
       } else {
         console.error(
           `Error fetching response detail for ID ${id}:`,
-          responseResult.reason,
+          responseResult.reason
         );
         setError(
           responseResult.reason instanceof Error
             ? responseResult.reason
-            : new Error("Failed to fetch response detail"),
+            : new Error("Failed to fetch response detail")
         );
       }

@@ -90,18 +90,18 @@ export default function ResponseDetailPage() {
       } else {
         console.error(
           `Error fetching input items for response ID ${id}:`,
-          inputItemsResult.reason,
+          inputItemsResult.reason
         );
         setInputItemsError(
           inputItemsResult.reason instanceof Error
             ? inputItemsResult.reason
-            : new Error("Failed to fetch input items"),
+            : new Error("Failed to fetch input items")
         );
       }
     } catch (err) {
       console.error(`Unexpected error fetching data for ID ${id}:`, err);
       setError(
-        err instanceof Error ? err : new Error("Unexpected error occurred"),
+        err instanceof Error ? err : new Error("Unexpected error occurred")
       );
     } finally {
       setIsLoading(false);
@@ -18,7 +18,10 @@ import {
   PropertiesCard,
   PropertyItem,
 } from "@/components/layout/detail-layout";
-import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb";
+import {
+  PageBreadcrumb,
+  BreadcrumbSegment,
+} from "@/components/layout/page-breadcrumb";

 export default function ContentDetailPage() {
   const params = useParams();

@@ -28,13 +31,13 @@ export default function ContentDetailPage() {
   const contentId = params.contentId as string;
   const client = useAuthClient();

-  const getTextFromContent = (content: any): string => {
-    if (typeof content === 'string') {
+  const getTextFromContent = (content: unknown): string => {
+    if (typeof content === "string") {
       return content;
-    } else if (content && content.type === 'text') {
+    } else if (content && content.type === "text") {
       return content.text;
     }
-    return '';
+    return "";
   };

   const [store, setStore] = useState<VectorStore | null>(null);

@@ -44,7 +47,9 @@ export default function ContentDetailPage() {
   const [error, setError] = useState<Error | null>(null);
   const [isEditing, setIsEditing] = useState(false);
   const [editedContent, setEditedContent] = useState("");
-  const [editedMetadata, setEditedMetadata] = useState<Record<string, any>>({});
+  const [editedMetadata, setEditedMetadata] = useState<Record<string, unknown>>(
+    {}
+  );
   const [isEditingEmbedding, setIsEditingEmbedding] = useState(false);
   const [editedEmbedding, setEditedEmbedding] = useState<number[]>([]);

@@ -64,8 +69,13 @@ export default function ContentDetailPage() {
         setFile(fileResponse as VectorStoreFile);

         const contentsAPI = new ContentsAPI(client);
-        const contentsResponse = await contentsAPI.listContents(vectorStoreId, fileId);
-        const targetContent = contentsResponse.data.find(c => c.id === contentId);
+        const contentsResponse = await contentsAPI.listContents(
+          vectorStoreId,
+          fileId
+        );
+        const targetContent = contentsResponse.data.find(
+          c => c.id === contentId
+        );

         if (targetContent) {
           setContent(targetContent);

@@ -76,7 +86,9 @@ export default function ContentDetailPage() {
           throw new Error(`Content ${contentId} not found`);
         }
       } catch (err) {
-        setError(err instanceof Error ? err : new Error("Failed to load content."));
+        setError(
+          err instanceof Error ? err : new Error("Failed to load content.")
+        );
       } finally {
         setIsLoading(false);
       }

@@ -88,7 +100,8 @@ export default function ContentDetailPage() {
     if (!content) return;

     try {
-      const updates: { content?: string; metadata?: Record<string, any> } = {};
+      const updates: { content?: string; metadata?: Record<string, unknown> } =
+        {};

       if (editedContent !== getTextFromContent(content.content)) {
         updates.content = editedContent;

@@ -100,25 +113,32 @@ export default function ContentDetailPage() {

       if (Object.keys(updates).length > 0) {
         const contentsAPI = new ContentsAPI(client);
-        const updatedContent = await contentsAPI.updateContent(vectorStoreId, fileId, contentId, updates);
+        const updatedContent = await contentsAPI.updateContent(
+          vectorStoreId,
+          fileId,
+          contentId,
+          updates
+        );
         setContent(updatedContent);
       }

       setIsEditing(false);
     } catch (err) {
-      console.error('Failed to update content:', err);
+      console.error("Failed to update content:", err);
     }
   };

   const handleDelete = async () => {
-    if (!confirm('Are you sure you want to delete this content?')) return;
+    if (!confirm("Are you sure you want to delete this content?")) return;

     try {
       const contentsAPI = new ContentsAPI(client);
       await contentsAPI.deleteContent(vectorStoreId, fileId, contentId);
-      router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`);
+      router.push(
+        `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`
+      );
     } catch (err) {
-      console.error('Failed to delete content:', err);
+      console.error("Failed to delete content:", err);
     }
   };

@@ -134,10 +154,19 @@ export default function ContentDetailPage() {

   const breadcrumbSegments: BreadcrumbSegment[] = [
     { label: "Vector Stores", href: "/logs/vector-stores" },
-    { label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` },
+    {
+      label: store?.name || vectorStoreId,
+      href: `/logs/vector-stores/${vectorStoreId}`,
+    },
     { label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
-    { label: fileId, href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}` },
-    { label: "Contents", href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents` },
+    {
+      label: fileId,
+      href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}`,
+    },
+    {
+      label: "Contents",
+      href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`,
+    },
     { label: contentId },
   ];

@@ -186,7 +215,7 @@ export default function ContentDetailPage() {
           {isEditing ? (
             <textarea
               value={editedContent}
-              onChange={(e) => setEditedContent(e.target.value)}
+              onChange={e => setEditedContent(e.target.value)}
               className="w-full h-64 p-3 border rounded-md resize-none font-mono text-sm"
               placeholder="Enter content..."
             />

@@ -206,16 +235,23 @@ export default function ContentDetailPage() {
               <div className="flex gap-2">
                 {isEditingEmbedding ? (
                   <>
-                    <Button size="sm" onClick={() => {
-                      setIsEditingEmbedding(false);
-                    }}>
+                    <Button
+                      size="sm"
+                      onClick={() => {
+                        setIsEditingEmbedding(false);
+                      }}
+                    >
                       <Save className="h-4 w-4 mr-1" />
                       Save
                     </Button>
-                    <Button size="sm" variant="outline" onClick={() => {
-                      setEditedEmbedding(content?.embedding || []);
-                      setIsEditingEmbedding(false);
-                    }}>
+                    <Button
+                      size="sm"
+                      variant="outline"
+                      onClick={() => {
+                        setEditedEmbedding(content?.embedding || []);
+                        setIsEditingEmbedding(false);
+                      }}
+                    >
                       <X className="h-4 w-4 mr-1" />
                       Cancel
                     </Button>

@@ -237,14 +273,16 @@ export default function ContentDetailPage() {
                   </p>
                   <textarea
                     value={JSON.stringify(editedEmbedding, null, 2)}
-                    onChange={(e) => {
+                    onChange={e => {
                       try {
                         const parsed = JSON.parse(e.target.value);
-                        if (Array.isArray(parsed) && parsed.every(v => typeof v === 'number')) {
+                        if (
+                          Array.isArray(parsed) &&
+                          parsed.every(v => typeof v === "number")
+                        ) {
                           setEditedEmbedding(parsed);
                         }
-                      } catch {
-                      }
+                      } catch {}
                     }}
                     className="w-full h-32 p-3 border rounded-md resize-none font-mono text-xs"
                     placeholder="Enter embedding as JSON array..."

@@ -259,8 +297,15 @@ export default function ContentDetailPage() {
                 </div>
                 <div className="p-3 bg-gray-50 dark:bg-gray-800 rounded-md max-h-32 overflow-y-auto">
                   <pre className="whitespace-pre-wrap font-mono text-xs text-gray-900 dark:text-gray-100">
-                    [{content.embedding.slice(0, 20).map(v => v.toFixed(6)).join(', ')}
-                    {content.embedding.length > 20 ? `\n... and ${content.embedding.length - 20} more values` : ''}]
+                    [
+                    {content.embedding
+                      .slice(0, 20)
+                      .map(v => v.toFixed(6))
+                      .join(", ")}
+                    {content.embedding.length > 20
+                      ? `\n... and ${content.embedding.length - 20} more values`
+                      : ""}
+                    ]
                   </pre>
                 </div>
               </div>

@@ -284,7 +329,7 @@ export default function ContentDetailPage() {
                     <div key={key} className="flex gap-2">
                       <Input
                         value={key}
-                        onChange={(e) => {
+                        onChange={e => {
                           const newMetadata = { ...editedMetadata };
                           delete newMetadata[key];
                           newMetadata[e.target.value] = value;

@@ -294,11 +339,13 @@ export default function ContentDetailPage() {
                         className="flex-1"
                       />
                       <Input
-                        value={typeof value === 'string' ? value : JSON.stringify(value)}
-                        onChange={(e) => {
+                        value={
+                          typeof value === "string" ? value : JSON.stringify(value)
+                        }
+                        onChange={e => {
                           setEditedMetadata({
                             ...editedMetadata,
-                            [key]: e.target.value
+                            [key]: e.target.value,
                           });
                         }}
                         placeholder="Value"

@@ -312,7 +359,7 @@ export default function ContentDetailPage() {
                       onClick={() => {
                         setEditedMetadata({
                           ...editedMetadata,
-                          ['']: ''
+                          [""]: "",
                         });
                       }}
                     >

@@ -325,7 +372,7 @@ export default function ContentDetailPage() {
                     <div key={key} className="flex justify-between py-1">
                       <span className="font-medium text-gray-600">{key}:</span>
                       <span className="font-mono text-sm">
-                        {typeof value === 'string' ? value : JSON.stringify(value)}
+                        {typeof value === "string" ? value : JSON.stringify(value)}
                       </span>
                     </div>
                   ))}

@@ -351,15 +398,15 @@ export default function ContentDetailPage() {
             value={`${getTextFromContent(content.content).length} chars`}
           />
           {content.metadata.chunk_window && (
-            <PropertyItem
-              label="Position"
-              value={content.metadata.chunk_window}
-            />
+            <PropertyItem label="Position" value={content.metadata.chunk_window} />
           )}
           {file && (
             <>
               <PropertyItem label="File Status" value={file.status} />
-              <PropertyItem label="File Usage" value={`${file.usage_bytes} bytes`} />
+              <PropertyItem
+                label="File Usage"
+                value={`${file.usage_bytes} bytes`}
+              />
             </>
           )}
           {store && (
@@ -18,7 +18,10 @@ import {
   PropertiesCard,
   PropertyItem,
 } from "@/components/layout/detail-layout";
-import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb";
+import {
+  PageBreadcrumb,
+  BreadcrumbSegment,
+} from "@/components/layout/page-breadcrumb";
 import {
   Table,
   TableBody,

@@ -36,23 +39,21 @@ export default function ContentsListPage() {
   const fileId = params.fileId as string;
   const client = useAuthClient();

-  const getTextFromContent = (content: any): string => {
-    if (typeof content === 'string') {
+  const getTextFromContent = (content: unknown): string => {
+    if (typeof content === "string") {
       return content;
-    } else if (content && content.type === 'text') {
+    } else if (content && content.type === "text") {
       return content.text;
     }
-    return '';
+    return "";
   };

   const [store, setStore] = useState<VectorStore | null>(null);
   const [file, setFile] = useState<VectorStoreFile | null>(null);
   const [contents, setContents] = useState<VectorStoreContentItem[]>([]);
   const [isLoadingStore, setIsLoadingStore] = useState(true);
-  const [isLoadingFile, setIsLoadingFile] = useState(true);
   const [isLoadingContents, setIsLoadingContents] = useState(true);
   const [errorStore, setErrorStore] = useState<Error | null>(null);
-  const [errorFile, setErrorFile] = useState<Error | null>(null);
   const [errorContents, setErrorContents] = useState<Error | null>(null);

   useEffect(() => {

@@ -65,7 +66,9 @@ export default function ContentsListPage() {
         const response = await client.vectorStores.retrieve(vectorStoreId);
         setStore(response as VectorStore);
       } catch (err) {
-        setErrorStore(err instanceof Error ? err : new Error("Failed to load vector store."));
+        setErrorStore(
+          err instanceof Error ? err : new Error("Failed to load vector store.")
+        );
       } finally {
         setIsLoadingStore(false);
       }

@@ -80,10 +83,15 @@ export default function ContentsListPage() {
       setIsLoadingFile(true);
       setErrorFile(null);
       try {
-        const response = await client.vectorStores.files.retrieve(vectorStoreId, fileId);
+        const response = await client.vectorStores.files.retrieve(
+          vectorStoreId,
+          fileId
+        );
         setFile(response as VectorStoreFile);
       } catch (err) {
-        setErrorFile(err instanceof Error ? err : new Error("Failed to load file."));
+        setErrorFile(
+          err instanceof Error ? err : new Error("Failed to load file.")
+        );
       } finally {
         setIsLoadingFile(false);
       }

@@ -99,10 +107,16 @@ export default function ContentsListPage() {
       setErrorContents(null);
       try {
         const contentsAPI = new ContentsAPI(client);
-        const contentsResponse = await contentsAPI.listContents(vectorStoreId, fileId, { limit: 100 });
+        const contentsResponse = await contentsAPI.listContents(
+          vectorStoreId,
+          fileId,
+          { limit: 100 }
+        );
         setContents(contentsResponse.data);
       } catch (err) {
-        setErrorContents(err instanceof Error ? err : new Error("Failed to load contents."));
+        setErrorContents(
+          err instanceof Error ? err : new Error("Failed to load contents.")
+        );
       } finally {
         setIsLoadingContents(false);
       }

@@ -116,26 +130,36 @@ export default function ContentsListPage() {
       await contentsAPI.deleteContent(vectorStoreId, fileId, contentId);
       setContents(contents.filter(content => content.id !== contentId));
     } catch (err) {
-      console.error('Failed to delete content:', err);
+      console.error("Failed to delete content:", err);
     }
   };

   const handleViewContent = (contentId: string) => {
-    router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents/${contentId}`);
+    router.push(
+      `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents/${contentId}`
+    );
   };

   const title = `Contents in File: ${fileId}`;

   const breadcrumbSegments: BreadcrumbSegment[] = [
     { label: "Vector Stores", href: "/logs/vector-stores" },
-    { label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` },
+    {
+      label: store?.name || vectorStoreId,
+      href: `/logs/vector-stores/${vectorStoreId}`,
+    },
     { label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
-    { label: fileId, href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}` },
+    {
+      label: fileId,
+      href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}`,
+    },
     { label: "Contents" },
   ];

   if (errorStore) {
-    return <DetailErrorView title={title} id={vectorStoreId} error={errorStore} />;
+    return (
+      <DetailErrorView title={title} id={vectorStoreId} error={errorStore} />
+    );
   }
   if (isLoadingStore) {
     return <DetailLoadingView title={title} />;

@@ -175,7 +199,7 @@ export default function ContentsListPage() {
               </TableRow>
             </TableHeader>
             <TableBody>
-              {contents.map((content) => (
+              {contents.map(content => (
                 <TableRow key={content.id}>
                   <TableCell className="font-mono text-xs">
                     <Button

@@ -189,7 +213,10 @@ export default function ContentsListPage() {
                   </TableCell>
                   <TableCell>
                     <div className="max-w-md">
-                      <p className="text-sm truncate" title={getTextFromContent(content.content)}>
+                      <p
+                        className="text-sm truncate"
+                        title={getTextFromContent(content.content)}
+                      >
                         {getTextFromContent(content.content)}
                       </p>
                     </div>

@@ -197,12 +224,25 @@ export default function ContentsListPage() {
                   <TableCell className="text-xs text-gray-500">
                     {content.embedding && content.embedding.length > 0 ? (
                       <div className="max-w-xs">
-                        <span className="font-mono text-xs bg-gray-100 dark:bg-gray-800 rounded px-1 py-0.5" title={`${content.embedding.length}D vector: [${content.embedding.slice(0, 3).map(v => v.toFixed(3)).join(', ')}...]`}>
-                          [{content.embedding.slice(0, 3).map(v => v.toFixed(3)).join(', ')}...] ({content.embedding.length}D)
+                        <span
+                          className="font-mono text-xs bg-gray-100 dark:bg-gray-800 rounded px-1 py-0.5"
+                          title={`${content.embedding.length}D vector: [${content.embedding
+                            .slice(0, 3)
+                            .map(v => v.toFixed(3))
+                            .join(", ")}...]`}
+                        >
+                          [
+                          {content.embedding
+                            .slice(0, 3)
+                            .map(v => v.toFixed(3))
+                            .join(", ")}
+                          ...] ({content.embedding.length}D)
                         </span>
                       </div>
                     ) : (
-                      <span className="text-gray-400 dark:text-gray-500 italic">No embedding</span>
+                      <span className="text-gray-400 dark:text-gray-500 italic">
+                        No embedding
+                      </span>
                     )}
                   </TableCell>
                   <TableCell className="text-xs text-gray-500">

@@ -211,7 +251,9 @@ export default function ContentsListPage() {
                       : `${content.metadata.content_length || 0} chars`}
                   </TableCell>
                   <TableCell className="text-xs">
-                    {new Date(content.created_timestamp * 1000).toLocaleString()}
+                    {new Date(
+                      content.created_timestamp * 1000
+                    ).toLocaleString()}
                   </TableCell>
                   <TableCell>
                     <div className="flex gap-1">
@@ -4,9 +4,12 @@ import { useEffect, useState } from "react";
import { useParams, useRouter } from "next/navigation";
import { useAuthClient } from "@/hooks/use-auth-client";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
-import type { VectorStoreFile, FileContentResponse } from "llama-stack-client/resources/vector-stores/files";
+import type {
+VectorStoreFile,
+FileContentResponse,
+} from "llama-stack-client/resources/vector-stores/files";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
-import { Skeleton } from '@/components/ui/skeleton';
+import { Skeleton } from "@/components/ui/skeleton";
import { Button } from "@/components/ui/button";
import { List } from "lucide-react";
import {

@@ -17,7 +20,10 @@ import {
PropertiesCard,
PropertyItem,
} from "@/components/layout/detail-layout";
-import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb";
+import {
+PageBreadcrumb,
+BreadcrumbSegment,
+} from "@/components/layout/page-breadcrumb";

export default function FileDetailPage() {
const params = useParams();
@ -46,7 +52,9 @@ export default function FileDetailPage() {
|
||||||
const response = await client.vectorStores.retrieve(vectorStoreId);
|
const response = await client.vectorStores.retrieve(vectorStoreId);
|
||||||
setStore(response as VectorStore);
|
setStore(response as VectorStore);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
setErrorStore(err instanceof Error ? err : new Error("Failed to load vector store."));
|
setErrorStore(
|
||||||
|
err instanceof Error ? err : new Error("Failed to load vector store.")
|
||||||
|
);
|
||||||
} finally {
|
} finally {
|
||||||
setIsLoadingStore(false);
|
setIsLoadingStore(false);
|
||||||
}
|
}
|
||||||
|
|
@ -61,10 +69,15 @@ export default function FileDetailPage() {
|
||||||
setIsLoadingFile(true);
|
setIsLoadingFile(true);
|
||||||
setErrorFile(null);
|
setErrorFile(null);
|
||||||
try {
|
try {
|
||||||
const response = await client.vectorStores.files.retrieve(vectorStoreId, fileId);
|
const response = await client.vectorStores.files.retrieve(
|
||||||
|
vectorStoreId,
|
||||||
|
fileId
|
||||||
|
);
|
||||||
setFile(response as VectorStoreFile);
|
setFile(response as VectorStoreFile);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
setErrorFile(err instanceof Error ? err : new Error("Failed to load file."));
|
setErrorFile(
|
||||||
|
err instanceof Error ? err : new Error("Failed to load file.")
|
||||||
|
);
|
||||||
} finally {
|
} finally {
|
||||||
setIsLoadingFile(false);
|
setIsLoadingFile(false);
|
||||||
}
|
}
|
||||||
|
|
@ -79,10 +92,15 @@ export default function FileDetailPage() {
|
||||||
setIsLoadingContents(true);
|
setIsLoadingContents(true);
|
||||||
setErrorContents(null);
|
setErrorContents(null);
|
||||||
try {
|
try {
|
||||||
const response = await client.vectorStores.files.content(vectorStoreId, fileId);
|
const response = await client.vectorStores.files.content(
|
||||||
|
vectorStoreId,
|
||||||
|
fileId
|
||||||
|
);
|
||||||
setContents(response);
|
setContents(response);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
setErrorContents(err instanceof Error ? err : new Error("Failed to load contents."));
|
setErrorContents(
|
||||||
|
err instanceof Error ? err : new Error("Failed to load contents.")
|
||||||
|
);
|
||||||
} finally {
|
} finally {
|
||||||
setIsLoadingContents(false);
|
setIsLoadingContents(false);
|
||||||
}
|
}
|
||||||
|
|
@ -91,20 +109,27 @@ export default function FileDetailPage() {
|
||||||
}, [vectorStoreId, fileId, client]);
|
}, [vectorStoreId, fileId, client]);
|
||||||
|
|
||||||
const handleViewContents = () => {
|
const handleViewContents = () => {
|
||||||
router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`);
|
router.push(
|
||||||
|
`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const title = `File: ${fileId}`;
|
const title = `File: ${fileId}`;
|
||||||
|
|
||||||
const breadcrumbSegments: BreadcrumbSegment[] = [
|
const breadcrumbSegments: BreadcrumbSegment[] = [
|
||||||
{ label: "Vector Stores", href: "/logs/vector-stores" },
|
{ label: "Vector Stores", href: "/logs/vector-stores" },
|
||||||
{ label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` },
|
{
|
||||||
|
label: store?.name || vectorStoreId,
|
||||||
|
href: `/logs/vector-stores/${vectorStoreId}`,
|
||||||
|
},
|
||||||
{ label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
|
{ label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
|
||||||
{ label: fileId },
|
{ label: fileId },
|
||||||
];
|
];
|
||||||
|
|
||||||
if (errorStore) {
|
if (errorStore) {
|
||||||
return <DetailErrorView title={title} id={vectorStoreId} error={errorStore} />;
|
return (
|
||||||
|
<DetailErrorView title={title} id={vectorStoreId} error={errorStore} />
|
||||||
|
);
|
||||||
}
|
}
|
||||||
if (isLoadingStore) {
|
if (isLoadingStore) {
|
||||||
return <DetailLoadingView title={title} />;
|
return <DetailLoadingView title={title} />;
|
||||||
|
|
@ -136,19 +161,29 @@ export default function FileDetailPage() {
|
||||||
<h3 className="text-lg font-medium mb-2">File Details</h3>
|
<h3 className="text-lg font-medium mb-2">File Details</h3>
|
||||||
<div className="grid grid-cols-2 gap-4 text-sm">
|
<div className="grid grid-cols-2 gap-4 text-sm">
|
||||||
<div>
|
<div>
|
||||||
<span className="font-medium text-gray-600 dark:text-gray-400">Status:</span>
|
<span className="font-medium text-gray-600 dark:text-gray-400">
|
||||||
|
Status:
|
||||||
|
</span>
|
||||||
<span className="ml-2">{file.status}</span>
|
<span className="ml-2">{file.status}</span>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<span className="font-medium text-gray-600 dark:text-gray-400">Size:</span>
|
<span className="font-medium text-gray-600 dark:text-gray-400">
|
||||||
|
Size:
|
||||||
|
</span>
|
||||||
<span className="ml-2">{file.usage_bytes} bytes</span>
|
<span className="ml-2">{file.usage_bytes} bytes</span>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<span className="font-medium text-gray-600 dark:text-gray-400">Created:</span>
|
<span className="font-medium text-gray-600 dark:text-gray-400">
|
||||||
<span className="ml-2">{new Date(file.created_at * 1000).toLocaleString()}</span>
|
Created:
|
||||||
|
</span>
|
||||||
|
<span className="ml-2">
|
||||||
|
{new Date(file.created_at * 1000).toLocaleString()}
|
||||||
|
</span>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<span className="font-medium text-gray-600 dark:text-gray-400">Content Strategy:</span>
|
<span className="font-medium text-gray-600 dark:text-gray-400">
|
||||||
|
Content Strategy:
|
||||||
|
</span>
|
||||||
<span className="ml-2">{file.chunking_strategy.type}</span>
|
<span className="ml-2">{file.chunking_strategy.type}</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
@ -166,9 +201,7 @@ export default function FileDetailPage() {
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<p className="text-gray-500 italic text-sm">
|
<p className="text-gray-500 italic text-sm">File not found.</p>
|
||||||
File not found.
|
|
||||||
</p>
|
|
||||||
)}
|
)}
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
|
|
@ -192,16 +225,27 @@ export default function FileDetailPage() {
|
||||||
<div className="space-y-3">
|
<div className="space-y-3">
|
||||||
<div className="grid grid-cols-2 gap-4 text-sm">
|
<div className="grid grid-cols-2 gap-4 text-sm">
|
||||||
<div>
|
<div>
|
||||||
<span className="font-medium text-gray-600 dark:text-gray-400">Content Items:</span>
|
<span className="font-medium text-gray-600 dark:text-gray-400">
|
||||||
|
Content Items:
|
||||||
|
</span>
|
||||||
<span className="ml-2">{contents.content.length}</span>
|
<span className="ml-2">{contents.content.length}</span>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<span className="font-medium text-gray-600 dark:text-gray-400">Total Characters:</span>
|
<span className="font-medium text-gray-600 dark:text-gray-400">
|
||||||
<span className="ml-2">{contents.content.reduce((total, item) => total + item.text.length, 0)}</span>
|
Total Characters:
|
||||||
|
</span>
|
||||||
|
<span className="ml-2">
|
||||||
|
{contents.content.reduce(
|
||||||
|
(total, item) => total + item.text.length,
|
||||||
|
0
|
||||||
|
)}
|
||||||
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="pt-2">
|
<div className="pt-2">
|
||||||
<span className="text-sm font-medium text-gray-600 dark:text-gray-400">Preview:</span>
|
<span className="text-sm font-medium text-gray-600 dark:text-gray-400">
|
||||||
|
Preview:
|
||||||
|
</span>
|
||||||
<div className="mt-1 bg-gray-50 dark:bg-gray-800 rounded-md p-3">
|
<div className="mt-1 bg-gray-50 dark:bg-gray-800 rounded-md p-3">
|
||||||
<p className="text-sm text-gray-900 dark:text-gray-100 line-clamp-3">
|
<p className="text-sm text-gray-900 dark:text-gray-100 line-clamp-3">
|
||||||
{contents.content[0]?.text.substring(0, 200)}...
|
{contents.content[0]?.text.substring(0, 200)}...
|
||||||
|
|
|
||||||
|
|
@@ -1,7 +1,7 @@
"use client";

import { useEffect, useState } from "react";
-import { useParams, useRouter } from "next/navigation";
+import { useParams } from "next/navigation";
import { useAuthClient } from "@/hooks/use-auth-client";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";

@@ -11,7 +11,6 @@ export default function VectorStoreDetailPage() {
const params = useParams();
const id = params.id as string;
const client = useAuthClient();
-const router = useRouter();

const [store, setStore] = useState<VectorStore | null>(null);
const [files, setFiles] = useState<VectorStoreFile[]>([]);

@@ -34,9 +33,7 @@ export default function VectorStoreDetailPage() {
setStore(response as VectorStore);
} catch (err) {
setErrorStore(
-err instanceof Error
-? err
-: new Error("Failed to load vector store."),
+err instanceof Error ? err : new Error("Failed to load vector store.")
);
} finally {
setIsLoadingStore(false);

@@ -55,18 +52,18 @@ export default function VectorStoreDetailPage() {
setIsLoadingFiles(true);
setErrorFiles(null);
try {
-const result = await client.vectorStores.files.list(id as any);
-setFiles((result as any).data);
+const result = await client.vectorStores.files.list(id);
+setFiles((result as { data: VectorStoreFile[] }).data);
} catch (err) {
setErrorFiles(
-err instanceof Error ? err : new Error("Failed to load files."),
+err instanceof Error ? err : new Error("Failed to load files.")
);
} finally {
setIsLoadingFiles(false);
}
};
fetchFiles();
-}, [id]);
+}, [id, client.vectorStores.files]);

return (
<VectorStoreDetailView
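In the hunk above, the `id as any` argument and the `(result as any).data` access give way to a plain `id` plus a narrow structural assertion that names only the shape the component actually reads. A minimal self-contained sketch of that pattern, assuming stand-in names (`listFiles`, `VectorStoreFileLike`) rather than the real llama-stack-client types:

// Stand-ins for the client call and the file shape; these names are
// assumptions made for this sketch, not the real llama-stack-client API.
interface VectorStoreFileLike {
  id: string;
  status: string;
}

async function listFiles(_storeId: string): Promise<unknown> {
  return { data: [{ id: "file_1", status: "completed" }] };
}

async function loadFiles(storeId: string): Promise<VectorStoreFileLike[]> {
  const result = await listFiles(storeId);
  // Assert only the `{ data: [...] }` shape that is actually read,
  // instead of widening everything to `any`.
  return (result as { data: VectorStoreFileLike[] }).data;
}

loadFiles("vs_123").then(files => console.log(files.length)); // 1

The assertion still bypasses the compiler for that one property, but unlike `as any` it keeps the element type (`VectorStoreFileLike[]`) flowing into the rest of the component.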
@ -1,7 +1,6 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import React from "react";
|
import React from "react";
|
||||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
|
||||||
import type {
|
import type {
|
||||||
ListVectorStoresResponse,
|
ListVectorStoresResponse,
|
||||||
VectorStore,
|
VectorStore,
|
||||||
|
|
@ -12,7 +11,6 @@ import { Button } from "@/components/ui/button";
|
||||||
import {
|
import {
|
||||||
Table,
|
Table,
|
||||||
TableBody,
|
TableBody,
|
||||||
TableCaption,
|
|
||||||
TableCell,
|
TableCell,
|
||||||
TableHead,
|
TableHead,
|
||||||
TableHeader,
|
TableHeader,
|
||||||
|
|
@ -21,7 +19,6 @@ import {
|
||||||
import { Skeleton } from "@/components/ui/skeleton";
|
import { Skeleton } from "@/components/ui/skeleton";
|
||||||
|
|
||||||
export default function VectorStoresPage() {
|
export default function VectorStoresPage() {
|
||||||
const client = useAuthClient();
|
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const {
|
const {
|
||||||
data: stores,
|
data: stores,
|
||||||
|
|
@@ -37,7 +34,7 @@ export default function VectorStoresPage() {
after: params.after,
limit: params.limit,
order: params.order,
-} as any);
+} as Parameters<typeof client.vectorStores.list>[0]);
return response as ListVectorStoresResponse;
},
errorMessagePrefix: "vector stores",
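Here the list-call options are asserted to `Parameters<typeof client.vectorStores.list>[0]` instead of `any`, so the argument type is derived from the client method's own signature. A small sketch of the same utility-type trick, using a stand-in `listStores` function whose option names are illustrative assumptions, not the real client signature:

// Stand-in for a paginated list call; the option names are assumptions.
function listStores(options: {
  after?: string;
  limit?: number;
  order?: "asc" | "desc";
}) {
  return { data: [] as string[], options };
}

// Derive the parameter type from the function itself rather than using `as any`.
const params = {
  limit: 20,
  order: "desc",
} as Parameters<typeof listStores>[0];

console.log(listStores(params).options.limit); // 20

If `listStores` ever changes its options, the assertion target updates automatically, which is the main advantage over a hand-written (or `any`) cast.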
const renderContent = () => {
|
const renderContent = () => {
|
||||||
if (status === "loading") {
|
if (status === "loading") {
|
||||||
return (
|
return (
|
||||||
<div className="space-y-2">
|
<div className="space-y-2">
|
||||||
<Skeleton className="h-8 w-full"/>
|
<Skeleton className="h-8 w-full" />
|
||||||
<Skeleton className="h-4 w-full"/>
|
<Skeleton className="h-4 w-full" />
|
||||||
<Skeleton className="h-4 w-full"/>
|
<Skeleton className="h-4 w-full" />
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -70,72 +67,72 @@ export default function VectorStoresPage() {
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="overflow-auto flex-1 min-h-0">
|
<div className="overflow-auto flex-1 min-h-0">
|
||||||
<Table>
|
<Table>
|
||||||
<TableHeader>
|
<TableHeader>
|
||||||
<TableRow>
|
<TableRow>
|
||||||
<TableHead>ID</TableHead>
|
<TableHead>ID</TableHead>
|
||||||
<TableHead>Name</TableHead>
|
<TableHead>Name</TableHead>
|
||||||
<TableHead>Created</TableHead>
|
<TableHead>Created</TableHead>
|
||||||
<TableHead>Completed</TableHead>
|
<TableHead>Completed</TableHead>
|
||||||
<TableHead>Cancelled</TableHead>
|
<TableHead>Cancelled</TableHead>
|
||||||
<TableHead>Failed</TableHead>
|
<TableHead>Failed</TableHead>
|
||||||
<TableHead>In Progress</TableHead>
|
<TableHead>In Progress</TableHead>
|
||||||
<TableHead>Total</TableHead>
|
<TableHead>Total</TableHead>
|
||||||
<TableHead>Usage Bytes</TableHead>
|
<TableHead>Usage Bytes</TableHead>
|
||||||
<TableHead>Provider ID</TableHead>
|
<TableHead>Provider ID</TableHead>
|
||||||
<TableHead>Provider Vector DB ID</TableHead>
|
<TableHead>Provider Vector DB ID</TableHead>
|
||||||
</TableRow>
|
</TableRow>
|
||||||
</TableHeader>
|
</TableHeader>
|
||||||
<TableBody>
|
<TableBody>
|
||||||
{stores.map((store) => {
|
{stores.map(store => {
|
||||||
const fileCounts = store.file_counts;
|
const fileCounts = store.file_counts;
|
||||||
const metadata = store.metadata || {};
|
const metadata = store.metadata || {};
|
||||||
const providerId = metadata.provider_id ?? "";
|
const providerId = metadata.provider_id ?? "";
|
||||||
const providerDbId = metadata.provider_vector_db_id ?? "";
|
const providerDbId = metadata.provider_vector_db_id ?? "";
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<TableRow
|
<TableRow
|
||||||
key={store.id}
|
key={store.id}
|
||||||
onClick={() => router.push(`/logs/vector-stores/${store.id}`)}
|
onClick={() => router.push(`/logs/vector-stores/${store.id}`)}
|
||||||
className="cursor-pointer hover:bg-muted/50"
|
className="cursor-pointer hover:bg-muted/50"
|
||||||
|
>
|
||||||
|
<TableCell>
|
||||||
|
<Button
|
||||||
|
variant="link"
|
||||||
|
className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300"
|
||||||
|
onClick={() =>
|
||||||
|
router.push(`/logs/vector-stores/${store.id}`)
|
||||||
|
}
|
||||||
>
|
>
|
||||||
<TableCell>
|
{store.id}
|
||||||
<Button
|
</Button>
|
||||||
variant="link"
|
</TableCell>
|
||||||
className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300"
|
<TableCell>{store.name}</TableCell>
|
||||||
onClick={() =>
|
<TableCell>
|
||||||
router.push(`/logs/vector-stores/${store.id}`)
|
{new Date(store.created_at * 1000).toLocaleString()}
|
||||||
}
|
</TableCell>
|
||||||
>
|
<TableCell>{fileCounts.completed}</TableCell>
|
||||||
{store.id}
|
<TableCell>{fileCounts.cancelled}</TableCell>
|
||||||
</Button>
|
<TableCell>{fileCounts.failed}</TableCell>
|
||||||
</TableCell>
|
<TableCell>{fileCounts.in_progress}</TableCell>
|
||||||
<TableCell>{store.name}</TableCell>
|
<TableCell>{fileCounts.total}</TableCell>
|
||||||
<TableCell>
|
<TableCell>{store.usage_bytes}</TableCell>
|
||||||
{new Date(store.created_at * 1000).toLocaleString()}
|
<TableCell>{providerId}</TableCell>
|
||||||
</TableCell>
|
<TableCell>{providerDbId}</TableCell>
|
||||||
<TableCell>{fileCounts.completed}</TableCell>
|
</TableRow>
|
||||||
<TableCell>{fileCounts.cancelled}</TableCell>
|
);
|
||||||
<TableCell>{fileCounts.failed}</TableCell>
|
})}
|
||||||
<TableCell>{fileCounts.in_progress}</TableCell>
|
</TableBody>
|
||||||
<TableCell>{fileCounts.total}</TableCell>
|
</Table>
|
||||||
<TableCell>{store.usage_bytes}</TableCell>
|
</div>
|
||||||
<TableCell>{providerId}</TableCell>
|
|
||||||
<TableCell>{providerDbId}</TableCell>
|
|
||||||
</TableRow>
|
|
||||||
);
|
|
||||||
})}
|
|
||||||
</TableBody>
|
|
||||||
</Table>
|
|
||||||
</div>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-4">
|
<div className="space-y-4">
|
||||||
<h1 className="text-2xl font-semibold">Vector Stores</h1>
|
<h1 className="text-2xl font-semibold">Vector Stores</h1>
|
||||||
{renderContent()}
|
{renderContent()}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ describe("ChatCompletionDetailView", () => {
|
||||||
isLoading={true}
|
isLoading={true}
|
||||||
error={null}
|
error={null}
|
||||||
id="test-id"
|
id="test-id"
|
||||||
/>,
|
/>
|
||||||
);
|
);
|
||||||
// Use the data-slot attribute for Skeletons
|
// Use the data-slot attribute for Skeletons
|
||||||
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
|
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
|
||||||
|
|
@ -28,10 +28,10 @@ describe("ChatCompletionDetailView", () => {
|
||||||
isLoading={false}
|
isLoading={false}
|
||||||
error={{ name: "Error", message: "Network Error" }}
|
error={{ name: "Error", message: "Network Error" }}
|
||||||
id="err-id"
|
id="err-id"
|
||||||
/>,
|
/>
|
||||||
);
|
);
|
||||||
expect(
|
expect(
|
||||||
screen.getByText(/Error loading details for ID err-id: Network Error/),
|
screen.getByText(/Error loading details for ID err-id: Network Error/)
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -42,11 +42,11 @@ describe("ChatCompletionDetailView", () => {
|
||||||
isLoading={false}
|
isLoading={false}
|
||||||
error={{ name: "Error", message: "" }}
|
error={{ name: "Error", message: "" }}
|
||||||
id="err-id"
|
id="err-id"
|
||||||
/>,
|
/>
|
||||||
);
|
);
|
||||||
// Use regex to match the error message regardless of whitespace
|
// Use regex to match the error message regardless of whitespace
|
||||||
expect(
|
expect(
|
||||||
screen.getByText(/Error loading details for ID\s*err-id\s*:/),
|
screen.getByText(/Error loading details for ID\s*err-id\s*:/)
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -57,11 +57,11 @@ describe("ChatCompletionDetailView", () => {
|
||||||
isLoading={false}
|
isLoading={false}
|
||||||
error={{} as Error}
|
error={{} as Error}
|
||||||
id="err-id"
|
id="err-id"
|
||||||
/>,
|
/>
|
||||||
);
|
);
|
||||||
// Use regex to match the error message regardless of whitespace
|
// Use regex to match the error message regardless of whitespace
|
||||||
expect(
|
expect(
|
||||||
screen.getByText(/Error loading details for ID\s*err-id\s*:/),
|
screen.getByText(/Error loading details for ID\s*err-id\s*:/)
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -72,10 +72,10 @@ describe("ChatCompletionDetailView", () => {
|
||||||
isLoading={false}
|
isLoading={false}
|
||||||
error={null}
|
error={null}
|
||||||
id="notfound-id"
|
id="notfound-id"
|
||||||
/>,
|
/>
|
||||||
);
|
);
|
||||||
expect(
|
expect(
|
||||||
screen.getByText("No details found for ID: notfound-id."),
|
screen.getByText("No details found for ID: notfound-id.")
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -100,7 +100,7 @@ describe("ChatCompletionDetailView", () => {
|
||||||
isLoading={false}
|
isLoading={false}
|
||||||
error={null}
|
error={null}
|
||||||
id={mockCompletion.id}
|
id={mockCompletion.id}
|
||||||
/>,
|
/>
|
||||||
);
|
);
|
||||||
// Input
|
// Input
|
||||||
expect(screen.getByText("Input")).toBeInTheDocument();
|
expect(screen.getByText("Input")).toBeInTheDocument();
|
||||||
|
|
@ -112,7 +112,7 @@ describe("ChatCompletionDetailView", () => {
|
||||||
expect(screen.getByText("Properties")).toBeInTheDocument();
|
expect(screen.getByText("Properties")).toBeInTheDocument();
|
||||||
expect(screen.getByText("Created:")).toBeInTheDocument();
|
expect(screen.getByText("Created:")).toBeInTheDocument();
|
||||||
expect(
|
expect(
|
||||||
screen.getByText(new Date(1710000000 * 1000).toLocaleString()),
|
screen.getByText(new Date(1710000000 * 1000).toLocaleString())
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
expect(screen.getByText("ID:")).toBeInTheDocument();
|
expect(screen.getByText("ID:")).toBeInTheDocument();
|
||||||
expect(screen.getByText("comp_123")).toBeInTheDocument();
|
expect(screen.getByText("comp_123")).toBeInTheDocument();
|
||||||
|
|
@ -150,7 +150,7 @@ describe("ChatCompletionDetailView", () => {
|
||||||
isLoading={false}
|
isLoading={false}
|
||||||
error={null}
|
error={null}
|
||||||
id={mockCompletion.id}
|
id={mockCompletion.id}
|
||||||
/>,
|
/>
|
||||||
);
|
);
|
||||||
// Output should include the tool call block (should be present twice: input and output)
|
// Output should include the tool call block (should be present twice: input and output)
|
||||||
const toolCallLabels = screen.getAllByText("Tool Call");
|
const toolCallLabels = screen.getAllByText("Tool Call");
|
||||||
|
|
@ -178,13 +178,13 @@ describe("ChatCompletionDetailView", () => {
|
||||||
isLoading={false}
|
isLoading={false}
|
||||||
error={null}
|
error={null}
|
||||||
id={mockCompletion.id}
|
id={mockCompletion.id}
|
||||||
/>,
|
/>
|
||||||
);
|
);
|
||||||
// Input section should be present but empty
|
// Input section should be present but empty
|
||||||
expect(screen.getByText("Input")).toBeInTheDocument();
|
expect(screen.getByText("Input")).toBeInTheDocument();
|
||||||
// Output section should show fallback message
|
// Output section should show fallback message
|
||||||
expect(
|
expect(
|
||||||
screen.getByText("No message found in assistant's choice."),
|
screen.getByText("No message found in assistant's choice.")
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
// Properties should show N/A for finish reason
|
// Properties should show N/A for finish reason
|
||||||
expect(screen.getByText("Finish Reason:")).toBeInTheDocument();
|
expect(screen.getByText("Finish Reason:")).toBeInTheDocument();
|
||||||
|
|
|
||||||
|
|
@@ -53,14 +53,14 @@ export function ChatCompletionDetailView({
{completion.choices?.[0]?.message?.tool_calls &&
Array.isArray(completion.choices[0].message.tool_calls) &&
!completion.input_messages?.some(
-(im) =>
+im =>
im.role === "assistant" &&
im.tool_calls &&
Array.isArray(im.tool_calls) &&
-im.tool_calls.length > 0,
+im.tool_calls.length > 0
)
? completion.choices[0].message.tool_calls.map(
-(toolCall: any, index: number) => {
+(toolCall: { function?: { name?: string } }, index: number) => {
const assistantToolCallMessage: ChatMessage = {
role: "assistant",
tool_calls: [toolCall],
@ -72,7 +72,7 @@ export function ChatCompletionDetailView({
|
||||||
message={assistantToolCallMessage}
|
message={assistantToolCallMessage}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
},
|
}
|
||||||
)
|
)
|
||||||
: null}
|
: null}
|
||||||
</CardContent>
|
</CardContent>
|
||||||
|
|
@ -89,7 +89,7 @@ export function ChatCompletionDetailView({
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<p className="text-gray-500 italic text-sm">
|
<p className="text-gray-500 italic text-sm">
|
||||||
No message found in assistant's choice.
|
No message found in assistant's choice.
|
||||||
</p>
|
</p>
|
||||||
)}
|
)}
|
||||||
</CardContent>
|
</CardContent>
|
||||||
|
|
@ -120,13 +120,18 @@ export function ChatCompletionDetailView({
|
||||||
value={
|
value={
|
||||||
<div>
|
<div>
|
||||||
<ul className="list-disc list-inside pl-4 mt-1">
|
<ul className="list-disc list-inside pl-4 mt-1">
|
||||||
{toolCalls.map((toolCall: any, index: number) => (
|
{toolCalls.map(
|
||||||
<li key={index}>
|
(
|
||||||
<span className="text-gray-900 font-medium">
|
toolCall: { function?: { name?: string } },
|
||||||
{toolCall.function?.name || "N/A"}
|
index: number
|
||||||
</span>
|
) => (
|
||||||
</li>
|
<li key={index}>
|
||||||
))}
|
<span className="text-gray-900 font-medium">
|
||||||
|
{toolCall.function?.name || "N/A"}
|
||||||
|
</span>
|
||||||
|
</li>
|
||||||
|
)
|
||||||
|
)}
|
||||||
</ul>
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ describe("ChatCompletionsTable", () => {
|
||||||
// Default pass-through implementations
|
// Default pass-through implementations
|
||||||
truncateText.mockImplementation((text: string | undefined) => text);
|
truncateText.mockImplementation((text: string | undefined) => text);
|
||||||
extractTextFromContentPart.mockImplementation((content: unknown) =>
|
extractTextFromContentPart.mockImplementation((content: unknown) =>
|
||||||
typeof content === "string" ? content : "extracted text",
|
typeof content === "string" ? content : "extracted text"
|
||||||
);
|
);
|
||||||
extractDisplayableText.mockImplementation((message: unknown) => {
|
extractDisplayableText.mockImplementation((message: unknown) => {
|
||||||
const msg = message as { content?: string };
|
const msg = message as { content?: string };
|
||||||
|
|
@ -138,7 +138,7 @@ describe("ChatCompletionsTable", () => {
|
||||||
if (row) {
|
if (row) {
|
||||||
fireEvent.click(row);
|
fireEvent.click(row);
|
||||||
expect(mockPush).toHaveBeenCalledWith(
|
expect(mockPush).toHaveBeenCalledWith(
|
||||||
"/logs/chat-completions/completion_123",
|
"/logs/chat-completions/completion_123"
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
throw new Error('Row with "Test prompt" not found for router mock test.');
|
throw new Error('Row with "Test prompt" not found for router mock test.');
|
||||||
|
|
@ -162,7 +162,7 @@ describe("ChatCompletionsTable", () => {
|
||||||
expect(tableCaption).toBeInTheDocument();
|
expect(tableCaption).toBeInTheDocument();
|
||||||
if (tableCaption) {
|
if (tableCaption) {
|
||||||
const captionSkeleton = tableCaption.querySelector(
|
const captionSkeleton = tableCaption.querySelector(
|
||||||
'[data-slot="skeleton"]',
|
'[data-slot="skeleton"]'
|
||||||
);
|
);
|
||||||
expect(captionSkeleton).toBeInTheDocument();
|
expect(captionSkeleton).toBeInTheDocument();
|
||||||
}
|
}
|
||||||
|
|
@ -172,7 +172,7 @@ describe("ChatCompletionsTable", () => {
|
||||||
expect(tableBody).toBeInTheDocument();
|
expect(tableBody).toBeInTheDocument();
|
||||||
if (tableBody) {
|
if (tableBody) {
|
||||||
const bodySkeletons = tableBody.querySelectorAll(
|
const bodySkeletons = tableBody.querySelectorAll(
|
||||||
'[data-slot="skeleton"]',
|
'[data-slot="skeleton"]'
|
||||||
);
|
);
|
||||||
expect(bodySkeletons.length).toBeGreaterThan(0);
|
expect(bodySkeletons.length).toBeGreaterThan(0);
|
||||||
}
|
}
|
||||||
|
|
@ -192,14 +192,14 @@ describe("ChatCompletionsTable", () => {
|
||||||
|
|
||||||
render(<ChatCompletionsTable {...defaultProps} />);
|
render(<ChatCompletionsTable {...defaultProps} />);
|
||||||
expect(
|
expect(
|
||||||
screen.getByText("Unable to load chat completions"),
|
screen.getByText("Unable to load chat completions")
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
expect(screen.getByText(errorMessage)).toBeInTheDocument();
|
expect(screen.getByText(errorMessage)).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
test.each([{ name: "Error", message: "" }, {}])(
|
test.each([{ name: "Error", message: "" }, {}])(
|
||||||
"renders default error message when error has no message",
|
"renders default error message when error has no message",
|
||||||
(errorObject) => {
|
errorObject => {
|
||||||
mockedUsePagination.mockReturnValue({
|
mockedUsePagination.mockReturnValue({
|
||||||
data: [],
|
data: [],
|
||||||
status: "error",
|
status: "error",
|
||||||
|
|
@ -210,14 +210,14 @@ describe("ChatCompletionsTable", () => {
|
||||||
|
|
||||||
render(<ChatCompletionsTable {...defaultProps} />);
|
render(<ChatCompletionsTable {...defaultProps} />);
|
||||||
expect(
|
expect(
|
||||||
screen.getByText("Unable to load chat completions"),
|
screen.getByText("Unable to load chat completions")
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
expect(
|
expect(
|
||||||
screen.getByText(
|
screen.getByText(
|
||||||
"An unexpected error occurred while loading the data.",
|
"An unexpected error occurred while loading the data."
|
||||||
),
|
)
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
},
|
}
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -225,7 +225,7 @@ describe("ChatCompletionsTable", () => {
|
||||||
test('renders "No chat completions found." and no table when data array is empty', () => {
|
test('renders "No chat completions found." and no table when data array is empty', () => {
|
||||||
render(<ChatCompletionsTable {...defaultProps} />);
|
render(<ChatCompletionsTable {...defaultProps} />);
|
||||||
expect(
|
expect(
|
||||||
screen.getByText("No chat completions found."),
|
screen.getByText("No chat completions found.")
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
|
|
||||||
// Ensure that the table structure is NOT rendered in the empty state
|
// Ensure that the table structure is NOT rendered in the empty state
|
||||||
|
|
@ -292,7 +292,7 @@ describe("ChatCompletionsTable", () => {
|
||||||
|
|
||||||
// Table caption
|
// Table caption
|
||||||
expect(
|
expect(
|
||||||
screen.getByText("A list of your recent chat completions."),
|
screen.getByText("A list of your recent chat completions.")
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
|
|
||||||
// Table headers
|
// Table headers
|
||||||
|
|
@ -306,14 +306,14 @@ describe("ChatCompletionsTable", () => {
|
||||||
expect(screen.getByText("Test output")).toBeInTheDocument();
|
expect(screen.getByText("Test output")).toBeInTheDocument();
|
||||||
expect(screen.getByText("llama-test-model")).toBeInTheDocument();
|
expect(screen.getByText("llama-test-model")).toBeInTheDocument();
|
||||||
expect(
|
expect(
|
||||||
screen.getByText(new Date(1710000000 * 1000).toLocaleString()),
|
screen.getByText(new Date(1710000000 * 1000).toLocaleString())
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
|
|
||||||
expect(screen.getByText("Another input")).toBeInTheDocument();
|
expect(screen.getByText("Another input")).toBeInTheDocument();
|
||||||
expect(screen.getByText("Another output")).toBeInTheDocument();
|
expect(screen.getByText("Another output")).toBeInTheDocument();
|
||||||
expect(screen.getByText("llama-another-model")).toBeInTheDocument();
|
expect(screen.getByText("llama-another-model")).toBeInTheDocument();
|
||||||
expect(
|
expect(
|
||||||
screen.getByText(new Date(1710001000 * 1000).toLocaleString()),
|
screen.getByText(new Date(1710001000 * 1000).toLocaleString())
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
@ -328,7 +328,7 @@ describe("ChatCompletionsTable", () => {
|
||||||
return typeof text === "string" && text.length > effectiveMaxLength
|
return typeof text === "string" && text.length > effectiveMaxLength
|
||||||
? text.slice(0, effectiveMaxLength) + "..."
|
? text.slice(0, effectiveMaxLength) + "..."
|
||||||
: text;
|
: text;
|
||||||
},
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
const longInput =
|
const longInput =
|
||||||
|
|
@ -368,7 +368,7 @@ describe("ChatCompletionsTable", () => {
|
||||||
|
|
||||||
// The truncated text should be present for both input and output
|
// The truncated text should be present for both input and output
|
||||||
const truncatedTexts = screen.getAllByText(
|
const truncatedTexts = screen.getAllByText(
|
||||||
longInput.slice(0, 10) + "...",
|
longInput.slice(0, 10) + "..."
|
||||||
);
|
);
|
||||||
expect(truncatedTexts.length).toBe(2); // one for input, one for output
|
expect(truncatedTexts.length).toBe(2); // one for input, one for output
|
||||||
});
|
});
|
||||||
|
|
@ -420,7 +420,7 @@ describe("ChatCompletionsTable", () => {
|
||||||
// Verify the extracted text appears in the table
|
// Verify the extracted text appears in the table
|
||||||
expect(screen.getByText("Extracted input")).toBeInTheDocument();
|
expect(screen.getByText("Extracted input")).toBeInTheDocument();
|
||||||
expect(
|
expect(
|
||||||
screen.getByText("Extracted output from assistant"),
|
screen.getByText("Extracted output from assistant")
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import {
|
||||||
UsePaginationOptions,
|
UsePaginationOptions,
|
||||||
ListChatCompletionsResponse,
|
ListChatCompletionsResponse,
|
||||||
} from "@/lib/types";
|
} from "@/lib/types";
|
||||||
|
import { ListChatCompletionsParams } from "@/lib/llama-stack-client";
|
||||||
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
|
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
|
||||||
import {
|
import {
|
||||||
extractTextFromContentPart,
|
extractTextFromContentPart,
|
||||||
|
|
@ -38,14 +39,14 @@ export function ChatCompletionsTable({
|
||||||
limit: number;
|
limit: number;
|
||||||
model?: string;
|
model?: string;
|
||||||
order?: string;
|
order?: string;
|
||||||
},
|
}
|
||||||
) => {
|
) => {
|
||||||
const response = await client.chat.completions.list({
|
const response = await client.chat.completions.list({
|
||||||
after: params.after,
|
after: params.after,
|
||||||
limit: params.limit,
|
limit: params.limit,
|
||||||
...(params.model && { model: params.model }),
|
...(params.model && { model: params.model }),
|
||||||
...(params.order && { order: params.order }),
|
...(params.order && { order: params.order }),
|
||||||
} as any);
|
} as ListChatCompletionsParams);
|
||||||
|
|
||||||
return response as ListChatCompletionsResponse;
|
return response as ListChatCompletionsResponse;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -37,21 +37,26 @@ export function ChatMessageItem({ message }: ChatMessageItemProps) {
|
||||||
) {
|
) {
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
{message.tool_calls.map((toolCall: any, index: number) => {
|
{message.tool_calls.map(
|
||||||
const formattedToolCall = formatToolCallToString(toolCall);
|
(
|
||||||
const toolCallContent = (
|
toolCall: { function?: { name?: string; arguments?: unknown } },
|
||||||
<ToolCallBlock>
|
index: number
|
||||||
{formattedToolCall || "Error: Could not display tool call"}
|
) => {
|
||||||
</ToolCallBlock>
|
const formattedToolCall = formatToolCallToString(toolCall);
|
||||||
);
|
const toolCallContent = (
|
||||||
return (
|
<ToolCallBlock>
|
||||||
<MessageBlock
|
{formattedToolCall || "Error: Could not display tool call"}
|
||||||
key={index}
|
</ToolCallBlock>
|
||||||
label="Tool Call"
|
);
|
||||||
content={toolCallContent}
|
return (
|
||||||
/>
|
<MessageBlock
|
||||||
);
|
key={index}
|
||||||
})}
|
label="Tool Call"
|
||||||
|
content={toolCallContent}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
)}
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,18 @@
|
||||||
"use client"
|
"use client";
|
||||||
|
|
||||||
import React, { useMemo, useState } from "react"
|
import React, { useMemo, useState } from "react";
|
||||||
import { cva, type VariantProps } from "class-variance-authority"
|
import { cva, type VariantProps } from "class-variance-authority";
|
||||||
import { motion } from "framer-motion"
|
import { motion } from "framer-motion";
|
||||||
import { Ban, ChevronRight, Code2, Loader2, Terminal } from "lucide-react"
|
import { Ban, ChevronRight, Code2, Loader2, Terminal } from "lucide-react";
|
||||||
|
|
||||||
import { cn } from "@/lib/utils"
|
import { cn } from "@/lib/utils";
|
||||||
import {
|
import {
|
||||||
Collapsible,
|
Collapsible,
|
||||||
CollapsibleContent,
|
CollapsibleContent,
|
||||||
CollapsibleTrigger,
|
CollapsibleTrigger,
|
||||||
} from "@/components/ui/collapsible"
|
} from "@/components/ui/collapsible";
|
||||||
import { FilePreview } from "@/components/ui/file-preview"
|
import { FilePreview } from "@/components/ui/file-preview";
|
||||||
import { MarkdownRenderer } from "@/components/chat-playground/markdown-renderer"
|
import { MarkdownRenderer } from "@/components/chat-playground/markdown-renderer";
|
||||||
|
|
||||||
const chatBubbleVariants = cva(
|
const chatBubbleVariants = cva(
|
||||||
"group/message relative break-words rounded-lg p-3 text-sm sm:max-w-[70%]",
|
"group/message relative break-words rounded-lg p-3 text-sm sm:max-w-[70%]",
|
||||||
|
|
@ -52,66 +52,66 @@ const chatBubbleVariants = cva(
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
)
|
);
|
||||||
|
|
||||||
type Animation = VariantProps<typeof chatBubbleVariants>["animation"]
|
type Animation = VariantProps<typeof chatBubbleVariants>["animation"];
|
||||||
|
|
||||||
interface Attachment {
|
interface Attachment {
|
||||||
name?: string
|
name?: string;
|
||||||
contentType?: string
|
contentType?: string;
|
||||||
url: string
|
url: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface PartialToolCall {
|
interface PartialToolCall {
|
||||||
state: "partial-call"
|
state: "partial-call";
|
||||||
toolName: string
|
toolName: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface ToolCall {
|
interface ToolCall {
|
||||||
state: "call"
|
state: "call";
|
||||||
toolName: string
|
toolName: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface ToolResult {
|
interface ToolResult {
|
||||||
state: "result"
|
state: "result";
|
||||||
toolName: string
|
toolName: string;
|
||||||
result: {
|
result: {
|
||||||
__cancelled?: boolean
|
__cancelled?: boolean;
|
||||||
[key: string]: any
|
[key: string]: unknown;
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolInvocation = PartialToolCall | ToolCall | ToolResult
|
type ToolInvocation = PartialToolCall | ToolCall | ToolResult;
|
||||||
|
|
||||||
interface ReasoningPart {
|
interface ReasoningPart {
|
||||||
type: "reasoning"
|
type: "reasoning";
|
||||||
reasoning: string
|
reasoning: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface ToolInvocationPart {
|
interface ToolInvocationPart {
|
||||||
type: "tool-invocation"
|
type: "tool-invocation";
|
||||||
toolInvocation: ToolInvocation
|
toolInvocation: ToolInvocation;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface TextPart {
|
interface TextPart {
|
||||||
type: "text"
|
type: "text";
|
||||||
text: string
|
text: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
// For compatibility with AI SDK types, not used
|
// For compatibility with AI SDK types, not used
|
||||||
interface SourcePart {
|
interface SourcePart {
|
||||||
type: "source"
|
type: "source";
|
||||||
source?: any
|
source?: unknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface FilePart {
|
interface FilePart {
|
||||||
type: "file"
|
type: "file";
|
||||||
mimeType: string
|
mimeType: string;
|
||||||
data: string
|
data: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface StepStartPart {
|
interface StepStartPart {
|
||||||
type: "step-start"
|
type: "step-start";
|
||||||
}
|
}
|
||||||
|
|
||||||
type MessagePart =
|
type MessagePart =
|
||||||
|
|
@ -120,22 +120,22 @@ type MessagePart =
|
||||||
| ToolInvocationPart
|
| ToolInvocationPart
|
||||||
| SourcePart
|
| SourcePart
|
||||||
| FilePart
|
| FilePart
|
||||||
| StepStartPart
|
| StepStartPart;
|
||||||
|
|
||||||
export interface Message {
|
export interface Message {
|
||||||
id: string
|
id: string;
|
||||||
role: "user" | "assistant" | (string & {})
|
role: "user" | "assistant" | (string & {});
|
||||||
content: string
|
content: string;
|
||||||
createdAt?: Date
|
createdAt?: Date;
|
||||||
experimental_attachments?: Attachment[]
|
experimental_attachments?: Attachment[];
|
||||||
toolInvocations?: ToolInvocation[]
|
toolInvocations?: ToolInvocation[];
|
||||||
parts?: MessagePart[]
|
parts?: MessagePart[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ChatMessageProps extends Message {
|
export interface ChatMessageProps extends Message {
|
||||||
showTimeStamp?: boolean
|
showTimeStamp?: boolean;
|
||||||
animation?: Animation
|
animation?: Animation;
|
||||||
actions?: React.ReactNode
|
actions?: React.ReactNode;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const ChatMessage: React.FC<ChatMessageProps> = ({
|
export const ChatMessage: React.FC<ChatMessageProps> = ({
|
||||||
|
|
@ -150,21 +150,21 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
||||||
parts,
|
parts,
|
||||||
}) => {
|
}) => {
|
||||||
const files = useMemo(() => {
|
const files = useMemo(() => {
|
||||||
return experimental_attachments?.map((attachment) => {
|
return experimental_attachments?.map(attachment => {
|
||||||
const dataArray = dataUrlToUint8Array(attachment.url)
|
const dataArray = dataUrlToUint8Array(attachment.url);
|
||||||
const file = new File([dataArray], attachment.name ?? "Unknown", {
|
const file = new File([dataArray], attachment.name ?? "Unknown", {
|
||||||
type: attachment.contentType,
|
type: attachment.contentType,
|
||||||
})
|
});
|
||||||
return file
|
return file;
|
||||||
})
|
});
|
||||||
}, [experimental_attachments])
|
}, [experimental_attachments]);
|
||||||
|
|
||||||
const isUser = role === "user"
|
const isUser = role === "user";
|
||||||
|
|
||||||
const formattedTime = createdAt?.toLocaleTimeString("en-US", {
|
const formattedTime = createdAt?.toLocaleTimeString("en-US", {
|
||||||
hour: "2-digit",
|
hour: "2-digit",
|
||||||
minute: "2-digit",
|
minute: "2-digit",
|
||||||
})
|
});
|
||||||
|
|
||||||
if (isUser) {
|
if (isUser) {
|
||||||
return (
|
return (
|
||||||
|
|
@ -174,7 +174,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
||||||
{files ? (
|
{files ? (
|
||||||
<div className="mb-1 flex flex-wrap gap-2">
|
<div className="mb-1 flex flex-wrap gap-2">
|
||||||
{files.map((file, index) => {
|
{files.map((file, index) => {
|
||||||
return <FilePreview file={file} key={index} />
|
return <FilePreview file={file} key={index} />;
|
||||||
})}
|
})}
|
||||||
</div>
|
</div>
|
||||||
) : null}
|
) : null}
|
||||||
|
|
@ -195,7 +195,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
||||||
</time>
|
</time>
|
||||||
) : null}
|
) : null}
|
||||||
</div>
|
</div>
|
||||||
)
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (parts && parts.length > 0) {
|
if (parts && parts.length > 0) {
|
||||||
|
|
@ -230,23 +230,23 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
||||||
</time>
|
</time>
|
||||||
) : null}
|
) : null}
|
||||||
</div>
|
</div>
|
||||||
)
|
);
|
||||||
} else if (part.type === "reasoning") {
|
} else if (part.type === "reasoning") {
|
||||||
return <ReasoningBlock key={`reasoning-${index}`} part={part} />
|
return <ReasoningBlock key={`reasoning-${index}`} part={part} />;
|
||||||
} else if (part.type === "tool-invocation") {
|
} else if (part.type === "tool-invocation") {
|
||||||
return (
|
return (
|
||||||
<ToolCall
|
<ToolCall
|
||||||
key={`tool-${index}`}
|
key={`tool-${index}`}
|
||||||
toolInvocations={[part.toolInvocation]}
|
toolInvocations={[part.toolInvocation]}
|
||||||
/>
|
/>
|
||||||
)
|
);
|
||||||
}
|
}
|
||||||
return null
|
return null;
|
||||||
})
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (toolInvocations && toolInvocations.length > 0) {
|
if (toolInvocations && toolInvocations.length > 0) {
|
||||||
return <ToolCall toolInvocations={toolInvocations} />
|
return <ToolCall toolInvocations={toolInvocations} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|
@ -272,17 +272,17 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
||||||
</time>
|
</time>
|
||||||
) : null}
|
) : null}
|
||||||
</div>
|
</div>
|
||||||
)
|
);
|
||||||
}
|
};
|
||||||
|
|
||||||
function dataUrlToUint8Array(data: string) {
|
function dataUrlToUint8Array(data: string) {
|
||||||
const base64 = data.split(",")[1]
|
const base64 = data.split(",")[1];
|
||||||
const buf = Buffer.from(base64, "base64")
|
const buf = Buffer.from(base64, "base64");
|
||||||
return new Uint8Array(buf)
|
return new Uint8Array(buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
const ReasoningBlock = ({ part }: { part: ReasoningPart }) => {
|
const ReasoningBlock = ({ part }: { part: ReasoningPart }) => {
|
||||||
const [isOpen, setIsOpen] = useState(false)
|
const [isOpen, setIsOpen] = useState(false);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="mb-2 flex flex-col items-start sm:max-w-[70%]">
|
<div className="mb-2 flex flex-col items-start sm:max-w-[70%]">
|
||||||
|
|
@ -319,20 +319,20 @@ const ReasoningBlock = ({ part }: { part: ReasoningPart }) => {
|
||||||
</CollapsibleContent>
|
</CollapsibleContent>
|
||||||
</Collapsible>
|
</Collapsible>
|
||||||
</div>
|
</div>
|
||||||
)
|
);
|
||||||
}
|
};
|
||||||
|
|
||||||
function ToolCall({
|
function ToolCall({
|
||||||
toolInvocations,
|
toolInvocations,
|
||||||
}: Pick<ChatMessageProps, "toolInvocations">) {
|
}: Pick<ChatMessageProps, "toolInvocations">) {
|
||||||
if (!toolInvocations?.length) return null
|
if (!toolInvocations?.length) return null;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col items-start gap-2">
|
<div className="flex flex-col items-start gap-2">
|
||||||
{toolInvocations.map((invocation, index) => {
|
{toolInvocations.map((invocation, index) => {
|
||||||
const isCancelled =
|
const isCancelled =
|
||||||
invocation.state === "result" &&
|
invocation.state === "result" &&
|
||||||
invocation.result.__cancelled === true
|
invocation.result.__cancelled === true;
|
||||||
|
|
||||||
if (isCancelled) {
|
if (isCancelled) {
|
||||||
return (
|
return (
|
||||||
|
|
@ -350,7 +350,7 @@ function ToolCall({
|
||||||
</span>
|
</span>
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
)
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (invocation.state) {
|
switch (invocation.state) {
|
||||||
|
|
@ -373,7 +373,7 @@ function ToolCall({
|
||||||
</span>
|
</span>
|
||||||
<Loader2 className="h-3 w-3 animate-spin" />
|
<Loader2 className="h-3 w-3 animate-spin" />
|
||||||
</div>
|
</div>
|
||||||
)
|
);
|
||||||
case "result":
|
case "result":
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
|
|
@ -395,11 +395,11 @@ function ToolCall({
|
||||||
{JSON.stringify(invocation.result, null, 2)}
|
{JSON.stringify(invocation.result, null, 2)}
|
||||||
</pre>
|
</pre>
|
||||||
</div>
|
</div>
|
||||||
)
|
);
|
||||||
default:
|
default:
|
||||||
return null
|
return null;
|
||||||
}
|
}
|
||||||
})}
|
})}
|
||||||
</div>
|
</div>
|
||||||
)
|
);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"use client"
|
"use client";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
forwardRef,
|
forwardRef,
|
||||||
|
|
@ -6,48 +6,48 @@ import {
|
||||||
useRef,
|
useRef,
|
||||||
useState,
|
useState,
|
||||||
type ReactElement,
|
type ReactElement,
|
||||||
} from "react"
|
} from "react";
|
||||||
import { ArrowDown, ThumbsDown, ThumbsUp } from "lucide-react"
|
import { ArrowDown, ThumbsDown, ThumbsUp } from "lucide-react";
|
||||||
|
|
||||||
import { cn } from "@/lib/utils"
|
import { cn } from "@/lib/utils";
|
||||||
import { useAutoScroll } from "@/hooks/use-auto-scroll"
|
import { useAutoScroll } from "@/hooks/use-auto-scroll";
|
||||||
import { Button } from "@/components/ui/button"
|
import { Button } from "@/components/ui/button";
|
||||||
import { type Message } from "@/components/chat-playground/chat-message"
|
import { type Message } from "@/components/chat-playground/chat-message";
|
||||||
import { CopyButton } from "@/components/ui/copy-button"
|
import { CopyButton } from "@/components/ui/copy-button";
|
||||||
import { MessageInput } from "@/components/chat-playground/message-input"
|
import { MessageInput } from "@/components/chat-playground/message-input";
|
||||||
import { MessageList } from "@/components/chat-playground/message-list"
|
import { MessageList } from "@/components/chat-playground/message-list";
|
||||||
import { PromptSuggestions } from "@/components/chat-playground/prompt-suggestions"
|
import { PromptSuggestions } from "@/components/chat-playground/prompt-suggestions";
|
||||||
|
|
||||||
interface ChatPropsBase {
|
interface ChatPropsBase {
|
||||||
handleSubmit: (
|
handleSubmit: (
|
||||||
event?: { preventDefault?: () => void },
|
event?: { preventDefault?: () => void },
|
||||||
options?: { experimental_attachments?: FileList }
|
options?: { experimental_attachments?: FileList }
|
||||||
) => void
|
) => void;
|
||||||
messages: Array<Message>
|
messages: Array<Message>;
|
||||||
input: string
|
input: string;
|
||||||
className?: string
|
className?: string;
|
||||||
handleInputChange: React.ChangeEventHandler<HTMLTextAreaElement>
|
handleInputChange: React.ChangeEventHandler<HTMLTextAreaElement>;
|
||||||
isGenerating: boolean
|
isGenerating: boolean;
|
||||||
stop?: () => void
|
stop?: () => void;
|
||||||
onRateResponse?: (
|
onRateResponse?: (
|
||||||
messageId: string,
|
messageId: string,
|
||||||
rating: "thumbs-up" | "thumbs-down"
|
rating: "thumbs-up" | "thumbs-down"
|
||||||
) => void
|
) => void;
|
||||||
setMessages?: (messages: any[]) => void
|
setMessages?: (messages: Message[]) => void;
|
||||||
transcribeAudio?: (blob: Blob) => Promise<string>
|
transcribeAudio?: (blob: Blob) => Promise<string>;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface ChatPropsWithoutSuggestions extends ChatPropsBase {
|
interface ChatPropsWithoutSuggestions extends ChatPropsBase {
|
-  append?: never
-  suggestions?: never
+  append?: never;
+  suggestions?: never;
 }

 interface ChatPropsWithSuggestions extends ChatPropsBase {
-  append: (message: { role: "user"; content: string }) => void
-  suggestions: string[]
+  append: (message: { role: "user"; content: string }) => void;
+  suggestions: string[];
 }

-type ChatProps = ChatPropsWithoutSuggestions | ChatPropsWithSuggestions
+type ChatProps = ChatPropsWithoutSuggestions | ChatPropsWithSuggestions;

 export function Chat({
   messages,

@@ -63,34 +63,34 @@ export function Chat({
    (formatting only in handleStop: trailing semicolons are added and single-parameter
    arrow callbacks such as `(m) => ...` become `m => ...`)

@@ -98,61 +98,66 @@ export function Chat({
    The untyped part callback in the cancellation handler gains an explicit shape; the
    remaining lines of this hunk and the hunks below (@@ -189,7 +194,7 @@,
    @@ -237,15 +242,15 @@, @@ -253,7 +258,7 @@ and @@ -281,7 +286,7 @@ in ChatMessages)
    only add semicolons or drop arrow-parameter parentheses:

-      const updatedParts = lastAssistantMessage.parts.map((part: any) => {
+      const updatedParts = lastAssistantMessage.parts.map(
+        (part: {
+          type: string;
+          toolInvocation?: { state: string; toolName: string };
+        }) => {
           if (
             part.type === "tool-invocation" &&
             part.toolInvocation &&
             part.toolInvocation.state === "call"
           ) {
-            needsUpdate = true
+            needsUpdate = true;
             return {
               ...part,
               toolInvocation: {
                 ...part.toolInvocation,
                 state: "result",
                 result: {
                   content: "Tool execution was cancelled",
                   __cancelled: true,
                 },
               },
-            }
+            };
           }
-          return part
-        })
+          return part;
+        }
+      );

@@ -294,56 +299,56 @@ export const ChatContainer = forwardRef<
    Besides semicolon-only changes to ChatContainer, ChatFormProps and createFileList,
    the previously commented-out pending guard in ChatForm's onSubmit is restored:

     const onSubmit = (event: React.FormEvent) => {
-      // if (isPending) {
-      //   event.preventDefault()
-      //   return
-      // }
+      if (isPending) {
+        event.preventDefault();
+        return;
+      }

       if (!files) {
-        handleSubmit(event)
-        return
+        handleSubmit(event);
+        return;
       }

-      const fileList = createFileList(files)
-      handleSubmit(event, { experimental_attachments: fileList })
-      setFiles(null)
-    }
+      const fileList = createFileList(files);
+      handleSubmit(event, { experimental_attachments: fileList });
+      setFiles(null);
+    };

 function createFileList(files: File[] | FileList): FileList {
-  const dataTransfer = new DataTransfer()
+  const dataTransfer = new DataTransfer();
   for (const file of Array.from(files)) {
-    dataTransfer.items.add(file)
+    dataTransfer.items.add(file);
   }
-  return dataTransfer.files
+  return dataTransfer.files;
 }

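The component that eventually renders these tool invocations is not part of this diff. A minimal sketch of how a consumer could detect the `__cancelled` marker written by handleStop above, assuming only the invocation shape visible in this hunk (the interface and helper names here are illustrative, not from the commit):

    // Illustrative only: mirrors the result shape written by handleStop above.
    interface ToolInvocationResult {
      content: string;
      __cancelled?: boolean; // special marker set when the user stops generation
    }

    interface ToolInvocationLike {
      state: "call" | "result";
      toolName: string;
      result?: ToolInvocationResult;
    }

    // Hypothetical helper a message renderer could use to show a "cancelled"
    // badge instead of normal tool output.
    function wasCancelled(invocation: ToolInvocationLike): boolean {
      return invocation.state === "result" && invocation.result?.__cancelled === true;
    }
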
@@ -1,11 +1,11 @@
-"use client"
+"use client";

-import { AnimatePresence, motion } from "framer-motion"
-import { X } from "lucide-react"
+import { AnimatePresence, motion } from "framer-motion";
+import { X } from "lucide-react";

 interface InterruptPromptProps {
-  isOpen: boolean
-  close: () => void
+  isOpen: boolean;
+  close: () => void;
 }

 export function InterruptPrompt({ isOpen, close }: InterruptPromptProps) {
@@ -37,5 +37,5 @@ export function InterruptPrompt({ isOpen, close }: InterruptPromptProps) {
         </motion.div>
       )}
     </AnimatePresence>
-  )
+  );
 }

@@ -1,12 +1,12 @@
    (formatting only: the React/Markdown/remark-gfm imports and the
    MarkdownRendererProps members gain trailing semicolons)

@@ -16,34 +16,34 @@ export function MarkdownRenderer({ children }: MarkdownRendererProps) {
    Besides semicolons, the shiki token state drops its `any`:

-    const [tokens, setTokens] = useState<any[] | null>(null)
+    const [tokens, setTokens] = useState<unknown[] | null>(null);

@@ -52,31 +52,31 @@ const HighlightedPre = React.memo(
    Besides semicolons, the unused error binding is removed:

-      } catch (error) {
+      } catch {
         if (mounted) {
-          setIsSupported(false)
+          setIsSupported(false);
         }
       }

@@ -89,7 +89,7 @@ const HighlightedPre = React.memo(
@@ -99,7 +99,7 @@ const HighlightedPre = React.memo(
@@ -107,15 +107,15 @@ const HighlightedPre = React.memo(
@@ -127,12 +127,12 @@ const CodeBlock = ({
    (formatting only: semicolons added in the token rendering, the CodeBlockProps
    interface and the CodeBlock body)

@@ -152,27 +152,27 @@ const CodeBlock = ({
    childrenTakeAllStringContents drops `any` and the unnecessary `let`:

-function childrenTakeAllStringContents(element: any): string {
+function childrenTakeAllStringContents(element: unknown): string {
   if (typeof element === "string") {
-    return element
+    return element;
   }

   if (element?.props?.children) {
-    let children = element.props.children
+    const children = element.props.children;

     if (Array.isArray(children)) {
       return children
-        .map((child) => childrenTakeAllStringContents(child))
-        .join("")
+        .map(child => childrenTakeAllStringContents(child))
+        .join("");
     } else {
-      return childrenTakeAllStringContents(children)
+      return childrenTakeAllStringContents(children);
     }
   }

-  return ""
+  return "";
 }

@@ -184,8 +184,14 @@ const COMPONENTS = {
    The markdown `code` renderer replaces its `any` props with an explicit shape:

   blockquote: withClass("blockquote", "border-l-2 border-primary pl-4"),
-  code: ({ children, className, node, ...rest }: any) => {
-    const match = /language-(\w+)/.exec(className || "")
+  code: ({
+    children,
+    className,
+  }: {
+    children: React.ReactNode;
+    className?: string;
+  }) => {
+    const match = /language-(\w+)/.exec(className || "");
     return match ? (
       <CodeBlock className={className} language={match[1]} {...rest}>
         {children}

@@ -199,9 +205,9 @@ const COMPONENTS = {
-  )
+  );
   },
-  pre: ({ children }: any) => children,
+  pre: ({ children }: { children: React.ReactNode }) => children,
   ol: withClass("ol", "list-decimal space-y-2 pl-6"),
   ul: withClass("ul", "list-disc space-y-2 pl-6"),
   li: withClass("li", "my-1.5"),

@@ -220,14 +226,14 @@ const COMPONENTS = {
    withClass drops its `any` as well:

-}
+};

 function withClass(Tag: keyof JSX.IntrinsicElements, classes: string) {
-  const Component = ({ node, ...props }: any) => (
+  const Component = ({ ...props }: Record<string, unknown>) => (
     <Tag className={classes} {...props} />
-  )
-  Component.displayName = Tag
-  return Component
+  );
+  Component.displayName = Tag;
+  return Component;
 }

-export default MarkdownRenderer
+export default MarkdownRenderer;

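For reference, the highlighting path touched above reduces to the following standalone pattern. This is a sketch, not part of the commit; it assumes only the shiki exports already used in the diff (codeToTokens, bundledLanguages) and the same github-light/github-dark theme pair:

    // Lazy shiki highlighting as used by HighlightedPre above.
    async function tokenizeIfSupported(code: string, language: string) {
      const { codeToTokens, bundledLanguages } = await import("shiki");

      // Unsupported languages fall back to a plain <pre> block in the component.
      if (!(language in bundledLanguages)) {
        return null;
      }

      const { tokens } = await codeToTokens(code, {
        lang: language as keyof typeof bundledLanguages,
        themes: { light: "github-light", dark: "github-dark" },
      });
      return tokens;
    }
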
@@ -1,41 +1,41 @@
    (formatting only: "use client", the imports and the MessageInput prop interfaces
    gain trailing semicolons)

@@ -48,8 +48,8 @@ export function MessageInput({
    (formatting only: semicolons on the isDragging/showInterruptPrompt state hooks)

@@ -61,123 +61,124 @@ export function MessageInput({
    Besides semicolons and un-parenthesised single-parameter arrows throughout the
    drag-and-drop, paste, keyboard and autosize helpers, the transcription callback
    drops its `as any` cast in favour of a typed change event:

   } = useAudioRecording({
     transcribeAudio,
-    onTranscriptionComplete: (text) => {
-      props.onChange?.({ target: { value: text } } as any)
+    onTranscriptionComplete: text => {
+      props.onChange?.({
+        target: { value: text },
+      } as React.ChangeEvent<HTMLTextAreaElement>);
     },
-  })
+  });

@@ -220,24 +221,24 @@ export function MessageInput({
@@ -256,8 +257,8 @@ export function MessageInput({
@@ -308,12 +309,12 @@ export function MessageInput({
@@ -333,29 +334,29 @@ function FileUploadOverlay({ isDragging }: FileUploadOverlayProps) {
@@ -385,12 +386,12 @@ function TranscribingOverlay() {
@@ -418,15 +419,15 @@ function RecordingPrompt({ isVisible, onStopRecording }: RecordingPromptProps) {
@@ -448,7 +449,7 @@ function RecordingControls({
@@ -459,8 +460,8 @@ function RecordingControls({
    (formatting only in the file-preview list, attachment button, showFileUploadDialog
    and the overlay/recording helpers: trailing semicolons are added and callbacks such
    as `(file) => ...`, `(resolve) => ...` and `(e) => ...` lose their parentheses)

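MessageInput leaves speech-to-text to the caller through the optional transcribeAudio prop, whose contract in the diff is (blob: Blob) => Promise<string>. A caller-side sketch is below; the /api/transcribe route and the FormData field names are hypothetical placeholders, since the commit does not prescribe a transcription backend:

    // Illustrative wiring only; swap the fetch call for a real speech-to-text backend.
    async function transcribeAudio(blob: Blob): Promise<string> {
      const form = new FormData();
      form.append("file", blob, "recording.webm");
      // Hypothetical endpoint that returns { text: string }.
      const res = await fetch("/api/transcribe", { method: "POST", body: form });
      const data = await res.json();
      return data.text as string;
    }

    // Usage: <MessageInput value={input} isGenerating={false} transcribeAudio={transcribeAudio} />
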
@@ -2,18 +2,18 @@ import {
   ChatMessage,
   type ChatMessageProps,
   type Message,
-} from "@/components/chat-playground/chat-message"
-import { TypingIndicator } from "@/components/chat-playground/typing-indicator"
+} from "@/components/chat-playground/chat-message";
+import { TypingIndicator } from "@/components/chat-playground/typing-indicator";

-type AdditionalMessageOptions = Omit<ChatMessageProps, keyof Message>
+type AdditionalMessageOptions = Omit<ChatMessageProps, keyof Message>;

 interface MessageListProps {
-  messages: Message[]
-  showTimeStamps?: boolean
-  isTyping?: boolean
+  messages: Message[];
+  showTimeStamps?: boolean;
+  isTyping?: boolean;
   messageOptions?:
     | AdditionalMessageOptions
-    | ((message: Message) => AdditionalMessageOptions)
+    | ((message: Message) => AdditionalMessageOptions);
 }

 export function MessageList({
@@ -28,7 +28,7 @@ export function MessageList({
     const additionalOptions =
       typeof messageOptions === "function"
         ? messageOptions(message)
-        : messageOptions
+        : messageOptions;

     return (
       <ChatMessage
@@ -37,9 +37,9 @@ export function MessageList({
         {...message}
         {...additionalOptions}
       />
-    )
+    );
   })}
   {isTyping && <TypingIndicator />}
 </div>
-  )
+  );
 }

@@ -1,7 +1,7 @@
 interface PromptSuggestionsProps {
-  label: string
-  append: (message: { role: "user"; content: string }) => void
-  suggestions: string[]
+  label: string;
+  append: (message: { role: "user"; content: string }) => void;
+  suggestions: string[];
 }

 export function PromptSuggestions({
@@ -13,7 +13,7 @@ export function PromptSuggestions({
     <div className="space-y-6">
       <h2 className="text-center text-2xl font-bold">{label}</h2>
       <div className="flex gap-6 text-sm">
-        {suggestions.map((suggestion) => (
+        {suggestions.map(suggestion => (
           <button
             key={suggestion}
             onClick={() => append({ role: "user", content: suggestion })}
@@ -24,5 +24,5 @@ export function PromptSuggestions({
         ))}
       </div>
     </div>
-  )
+  );
 }

@@ -1,4 +1,4 @@
-import { Dot } from "lucide-react"
+import { Dot } from "lucide-react";

 export function TypingIndicator() {
   return (
@@ -11,5 +11,5 @@ export function TypingIndicator() {
       </div>
     </div>
   </div>
-  )
+  );
 }

@@ -56,18 +56,19 @@ const manageItems = [
@@ -79,7 +80,7 @@ export function AppSidebar() {
@@ -88,14 +89,14 @@ export function AppSidebar() {
@@ -106,46 +107,48 @@ export function AppSidebar() {
    (formatting only: the optimizeItems type annotation and the "(Coming Soon)" badge
    span are re-wrapped across lines, trailing commas after the long className strings
    inside cn(...) are dropped, the closing Sidebar/SidebarGroup tags are re-indented,
    and map callbacks such as `(item) => ...` become `item => ...`; no behavioural
    change)

@@ -2,7 +2,7 @@ import React from "react";
 import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
 import { Skeleton } from "@/components/ui/skeleton";

-export function DetailLoadingView({ title }: { title: string }) {
+export function DetailLoadingView() {
   return (
     <>
       <Skeleton className="h-8 w-3/4 mb-6" /> {/* Title Skeleton */}

@@ -67,7 +67,7 @@ describe("LogsTable Viewport Loading", () => {
@@ -81,11 +81,11 @@ describe("LogsTable Viewport Loading", () => {
    (formatting only: trailing commas after the last argument of render(...) and
    waitFor(...) calls are dropped, and `(resolve) => ...` becomes `resolve => ...`)

@@ -94,15 +94,11 @@ describe("LogsTable Viewport Loading", () => {
    A multi-line render call is collapsed onto one line:

-    render(
-      <LogsTable
-        {...defaultProps}
-        status="loading"
-        onLoadMore={mockLoadMore}
-      />,
-    );
+    render(
+      <LogsTable {...defaultProps} status="loading" onLoadMore={mockLoadMore} />
+    );

@@ -111,18 +107,18 @@ describe("LogsTable Viewport Loading", () => {
@@ -132,7 +128,7 @@ describe("LogsTable Viewport Loading", () => {
    (formatting only: the same trailing-comma and arrow-parameter cleanup in the
    hasMore and sentinel-element tests)

@@ -70,7 +70,7 @@ describe("LogsTable", () => {
@@ -78,7 +78,7 @@ describe("LogsTable", () => {
@@ -88,7 +88,7 @@ describe("LogsTable", () => {
@@ -102,7 +102,7 @@ describe("LogsTable", () => {
@@ -118,10 +118,10 @@ describe("LogsTable", () => {
@@ -132,29 +132,25 @@ describe("LogsTable", () => {
@@ -164,7 +160,7 @@ describe("LogsTable", () => {
@@ -178,7 +174,7 @@ describe("LogsTable", () => {
@@ -214,7 +210,7 @@ describe("LogsTable", () => {
@@ -311,8 +307,8 @@ describe("LogsTable", () => {
@@ -332,12 +328,12 @@ describe("LogsTable", () => {
    (formatting only: trailing commas after the final argument of render(...),
    screen.getByText(...), querySelector(...)/querySelectorAll(...) and expect(...)
    calls are dropped, the multi-line
    screen.getByText("An unexpected error occurred while loading the data.") calls are
    collapsed to one line, and `(textElement) => ...` becomes `textElement => ...`;
    no test logic changes)

@@ -142,7 +142,7 @@ export function LogsTable({
       <Table>
         <TableCaption className="sr-only">{caption}</TableCaption>
         <TableBody>
-          {data.map((row) => (
+          {data.map(row => (
             <TableRow
               key={row.id}
               onClick={() => router.push(row.detailPath)}

@@ -22,7 +22,7 @@ export function GroupedItemsDisplay({

   return (
     <>
-      {groupedItems.map((groupedItem) => {
+      {groupedItems.map(groupedItem => {
         // If this is a function call with an output, render the grouped component
         if (
           groupedItem.outputItem &&

@@ -18,7 +18,7 @@ export interface GroupedItem {
  * @returns Array of grouped items with their outputs
  */
 export function useFunctionCallGrouping(
-  items: AnyResponseItem[],
+  items: AnyResponseItem[]
 ): GroupedItem[] {
   return useMemo(() => {
     const groupedItems: GroupedItem[] = [];

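The hook's call site is not shown in this hunk; a minimal usage sketch, where only the signature (AnyResponseItem[] in, GroupedItem[] out) comes from the diff and the component and import path are assumptions:

    // Illustrative only: real rendering of the grouped items lives in GroupedItemsDisplay.
    import { useFunctionCallGrouping } from "./hooks/function-call-grouping"; // path assumed

    function ResponseItems({ items }: { items: AnyResponseItem[] }) {
      // Pairs each function-call item with its matching output item, if any.
      const grouped: GroupedItem[] = useFunctionCallGrouping(items);
      return <>{grouped.length} grouped item(s)</>;
    }
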
@@ -52,7 +52,7 @@ export function ItemRenderer({
     // Fallback to generic item for unknown types
     return (
       <GenericItemComponent
-        item={item as any}
+        item={item as Record<string, unknown>}
         index={index}
         keyPrefix={keyPrefix}
       />

@@ -20,7 +20,7 @@ export function MessageItemComponent({
     content = item.content;
   } else if (Array.isArray(item.content)) {
     content = item.content
-      .map((c) => {
+      .map(c => {
         return c.type === "input_text" || c.type === "output_text"
           ? c.text
           : JSON.stringify(c);

@@ -18,7 +18,7 @@ describe("ResponseDetailView", () => {
 describe("Loading State", () => {
 test("renders loading skeleton when isLoading is true", () => {
 const { container } = render(
-<ResponseDetailView {...defaultProps} isLoading={true} />,
+<ResponseDetailView {...defaultProps} isLoading={true} />
 );

 // Check for skeleton elements

@@ -36,13 +36,13 @@ describe("ResponseDetailView", () => {
 <ResponseDetailView
 {...defaultProps}
 error={{ name: "Error", message: errorMessage }}
-/>,
+/>
 );

 expect(screen.getByText("Responses Details")).toBeInTheDocument();
 // The error message is split across elements, so we check for parts
 expect(
-screen.getByText(/Error loading details for ID/),
+screen.getByText(/Error loading details for ID/)
 ).toBeInTheDocument();
 expect(screen.getByText(/test_id/)).toBeInTheDocument();
 expect(screen.getByText(/Network Error/)).toBeInTheDocument();

@@ -53,11 +53,11 @@ describe("ResponseDetailView", () => {
 <ResponseDetailView
 {...defaultProps}
 error={{ name: "Error", message: "" }}
-/>,
+/>
 );

 expect(
-screen.getByText(/Error loading details for ID/),
+screen.getByText(/Error loading details for ID/)
 ).toBeInTheDocument();
 expect(screen.getByText(/test_id/)).toBeInTheDocument();
 });

@@ -124,14 +124,14 @@ describe("ResponseDetailView", () => {
 // Check properties - use regex to handle text split across elements
 expect(screen.getByText(/Created/)).toBeInTheDocument();
 expect(
-screen.getByText(new Date(1710000000 * 1000).toLocaleString()),
+screen.getByText(new Date(1710000000 * 1000).toLocaleString())
 ).toBeInTheDocument();

 // Check for the specific ID label (not Previous Response ID)
 expect(
 screen.getByText((content, element) => {
 return element?.tagName === "STRONG" && content === "ID:";
-}),
+})
 ).toBeInTheDocument();
 expect(screen.getByText("resp_123")).toBeInTheDocument();

@@ -166,7 +166,7 @@ describe("ResponseDetailView", () => {
 };

 render(
-<ResponseDetailView {...defaultProps} response={minimalResponse} />,
+<ResponseDetailView {...defaultProps} response={minimalResponse} />
 );

 // Should show required properties

@@ -179,7 +179,7 @@ describe("ResponseDetailView", () => {
 expect(screen.queryByText("Top P")).not.toBeInTheDocument();
 expect(screen.queryByText("Parallel Tool Calls")).not.toBeInTheDocument();
 expect(
-screen.queryByText("Previous Response ID"),
+screen.queryByText("Previous Response ID")
 ).not.toBeInTheDocument();
 });

@@ -196,7 +196,7 @@ describe("ResponseDetailView", () => {

 // The error is shown in the properties sidebar, not as a separate "Error" label
 expect(
-screen.getByText("invalid_request: The request was invalid"),
+screen.getByText("invalid_request: The request was invalid")
 ).toBeInTheDocument();
 });
 });
@@ -218,7 +218,7 @@ describe("ResponseDetailView", () => {
 {...defaultProps}
 response={mockResponse}
 isLoadingInputItems={true}
-/>,
+/>
 );

 // Check for skeleton loading in input items section

@@ -227,7 +227,7 @@ describe("ResponseDetailView", () => {
 {...defaultProps}
 response={mockResponse}
 isLoadingInputItems={true}
-/>,
+/>
 );

 const skeletons = container.querySelectorAll('[data-slot="skeleton"]');

@@ -243,16 +243,16 @@ describe("ResponseDetailView", () => {
 name: "Error",
 message: "Failed to load input items",
 }}
-/>,
+/>
 );

 expect(
 screen.getByText(
-"Error loading input items: Failed to load input items",
+"Error loading input items: Failed to load input items"
-),
+)
 ).toBeInTheDocument();
 expect(
-screen.getByText("Falling back to response input data."),
+screen.getByText("Falling back to response input data.")
 ).toBeInTheDocument();

 // Should still show fallback input data

@@ -276,7 +276,7 @@ describe("ResponseDetailView", () => {
 {...defaultProps}
 response={mockResponse}
 inputItems={mockInputItems}
-/>,
+/>
 );

 // Should show input items data, not response.input

@@ -295,7 +295,7 @@ describe("ResponseDetailView", () => {
 {...defaultProps}
 response={mockResponse}
 inputItems={emptyInputItems}
-/>,
+/>
 );

 // Should show fallback input data

@@ -313,7 +313,7 @@ describe("ResponseDetailView", () => {
 {...defaultProps}
 response={responseWithoutInput}
 inputItems={null}
-/>,
+/>
 );

 expect(screen.getByText("No input data available.")).toBeInTheDocument();
@@ -443,7 +443,7 @@ describe("ResponseDetailView", () => {
 render(<ResponseDetailView {...defaultProps} response={mockResponse} />);

 expect(
-screen.getByText('input_function({"param": "value"})'),
+screen.getByText('input_function({"param": "value"})')
 ).toBeInTheDocument();
 expect(screen.getByText("Function Call")).toBeInTheDocument();
 });

@@ -468,7 +468,7 @@ describe("ResponseDetailView", () => {
 render(<ResponseDetailView {...defaultProps} response={mockResponse} />);

 expect(
-screen.getByText("web_search_call(status: completed)"),
+screen.getByText("web_search_call(status: completed)")
 ).toBeInTheDocument();
 expect(screen.getByText("Function Call")).toBeInTheDocument();
 expect(screen.getByText("(Web Search)")).toBeInTheDocument();

@@ -522,7 +522,7 @@ describe("ResponseDetailView", () => {
 render(<ResponseDetailView {...defaultProps} response={mockResponse} />);

 expect(
-screen.getByText("First output Second output"),
+screen.getByText("First output Second output")
 ).toBeInTheDocument();
 expect(screen.getByText("Assistant")).toBeInTheDocument();
 });

@@ -549,7 +549,7 @@ describe("ResponseDetailView", () => {
 render(<ResponseDetailView {...defaultProps} response={mockResponse} />);

 expect(
-screen.getByText('search_function({"query": "test"})'),
+screen.getByText('search_function({"query": "test"})')
 ).toBeInTheDocument();
 expect(screen.getByText("Function Call")).toBeInTheDocument();
 });

@@ -598,7 +598,7 @@ describe("ResponseDetailView", () => {
 render(<ResponseDetailView {...defaultProps} response={mockResponse} />);

 expect(
-screen.getByText("web_search_call(status: completed)"),
+screen.getByText("web_search_call(status: completed)")
 ).toBeInTheDocument();
 expect(screen.getByText(/Function Call/)).toBeInTheDocument();
 expect(screen.getByText("(Web Search)")).toBeInTheDocument();

@@ -616,7 +616,7 @@ describe("ResponseDetailView", () => {
 type: "unknown_type",
 custom_field: "custom_value",
 data: { nested: "object" },
-} as any,
+} as unknown,
 ],
 input: [],
 };

@@ -625,7 +625,7 @@ describe("ResponseDetailView", () => {

 // Should show JSON stringified content
 expect(
-screen.getByText(/custom_field.*custom_value/),
+screen.getByText(/custom_field.*custom_value/)
 ).toBeInTheDocument();
 expect(screen.getByText("(unknown_type)")).toBeInTheDocument();
 });

@@ -666,7 +666,7 @@ describe("ResponseDetailView", () => {
 role: "assistant",
 call_id: "call_123",
 content: "sunny and warm",
-} as any, // Using any to bypass the type restriction for this test
+} as unknown, // Using any to bypass the type restriction for this test
 ],
 input: [],
 };

@@ -676,7 +676,7 @@ describe("ResponseDetailView", () => {
 // Should show the function call and message as separate items (not grouped)
 expect(screen.getByText("Function Call")).toBeInTheDocument();
 expect(
-screen.getByText('get_weather({"city": "Tokyo"})'),
+screen.getByText('get_weather({"city": "Tokyo"})')
 ).toBeInTheDocument();
 expect(screen.getByText("Assistant")).toBeInTheDocument();
 expect(screen.getByText("sunny and warm")).toBeInTheDocument();

@@ -706,7 +706,7 @@ describe("ResponseDetailView", () => {
 status: "completed",
 call_id: "call_123",
 output: "sunny and warm",
-} as any, // Using any to bypass the type restriction for this test
+} as unknown,
 ],
 input: [],
 };

@@ -717,7 +717,7 @@ describe("ResponseDetailView", () => {
 expect(screen.getByText("Function Call")).toBeInTheDocument();
 expect(screen.getByText("Arguments")).toBeInTheDocument();
 expect(
-screen.getByText('get_weather({"city": "Tokyo"})'),
+screen.getByText('get_weather({"city": "Tokyo"})')
 ).toBeInTheDocument();
 // Use getAllByText since there are multiple "Output" elements (card title and output label)
 const outputElements = screen.getAllByText("Output");
@@ -146,7 +146,7 @@ describe("ResponsesTable", () => {
 expect(tableCaption).toBeInTheDocument();
 if (tableCaption) {
 const captionSkeleton = tableCaption.querySelector(
-'[data-slot="skeleton"]',
+'[data-slot="skeleton"]'
 );
 expect(captionSkeleton).toBeInTheDocument();
 }

@@ -156,7 +156,7 @@ describe("ResponsesTable", () => {
 expect(tableBody).toBeInTheDocument();
 if (tableBody) {
 const bodySkeletons = tableBody.querySelectorAll(
-'[data-slot="skeleton"]',
+'[data-slot="skeleton"]'
 );
 expect(bodySkeletons.length).toBeGreaterThan(0);
 }

@@ -176,14 +176,14 @@ describe("ResponsesTable", () => {

 render(<ResponsesTable {...defaultProps} />);
 expect(
-screen.getByText("Unable to load chat completions"),
+screen.getByText("Unable to load chat completions")
 ).toBeInTheDocument();
 expect(screen.getByText(errorMessage)).toBeInTheDocument();
 });

 test.each([{ name: "Error", message: "" }, {}])(
 "renders default error message when error has no message",
-(errorObject) => {
+errorObject => {
 mockedUsePagination.mockReturnValue({
 data: [],
 status: "error",

@@ -194,14 +194,14 @@ describe("ResponsesTable", () => {

 render(<ResponsesTable {...defaultProps} />);
 expect(
-screen.getByText("Unable to load chat completions"),
+screen.getByText("Unable to load chat completions")
 ).toBeInTheDocument();
 expect(
 screen.getByText(
-"An unexpected error occurred while loading the data.",
+"An unexpected error occurred while loading the data."
-),
+)
 ).toBeInTheDocument();
-},
+}
 );
 });

@@ -275,7 +275,7 @@ describe("ResponsesTable", () => {

 // Table caption
 expect(
-screen.getByText("A list of your recent responses."),
+screen.getByText("A list of your recent responses.")
 ).toBeInTheDocument();

 // Table headers

@@ -289,14 +289,14 @@ describe("ResponsesTable", () => {
 expect(screen.getByText("Test output")).toBeInTheDocument();
 expect(screen.getByText("llama-test-model")).toBeInTheDocument();
 expect(
-screen.getByText(new Date(1710000000 * 1000).toLocaleString()),
+screen.getByText(new Date(1710000000 * 1000).toLocaleString())
 ).toBeInTheDocument();

 expect(screen.getByText("Another input")).toBeInTheDocument();
 expect(screen.getByText("Another output")).toBeInTheDocument();
 expect(screen.getByText("llama-another-model")).toBeInTheDocument();
 expect(
-screen.getByText(new Date(1710001000 * 1000).toLocaleString()),
+screen.getByText(new Date(1710001000 * 1000).toLocaleString())
 ).toBeInTheDocument();
 });
 });

@@ -487,7 +487,7 @@ describe("ResponsesTable", () => {

 render(<ResponsesTable {...defaultProps} />);
 expect(
-screen.getByText('search_function({"query": "test"})'),
+screen.getByText('search_function({"query": "test"})')
 ).toBeInTheDocument();
 });

@@ -548,7 +548,7 @@ describe("ResponsesTable", () => {

 render(<ResponsesTable {...defaultProps} />);
 expect(
-screen.getByText("web_search_call(status: completed)"),
+screen.getByText("web_search_call(status: completed)")
 ).toBeInTheDocument();
 });

@@ -565,7 +565,7 @@ describe("ResponsesTable", () => {
 id: "unknown_123",
 status: "completed",
 custom_field: "custom_value",
-} as any,
+} as unknown,
 ],
 input: [{ type: "message", content: "input" }],
 };

@@ -594,7 +594,7 @@ describe("ResponsesTable", () => {
 {
 type: "unknown_type",
 data: "some data",
-} as any,
+} as unknown,
 ],
 input: [{ type: "message", content: "input" }],
 };

@@ -623,7 +623,7 @@ describe("ResponsesTable", () => {
 return typeof text === "string" && text.length > effectiveMaxLength
 ? text.slice(0, effectiveMaxLength) + "..."
 : text;
-},
+}
 );

 const longInput =

@@ -665,7 +665,7 @@ describe("ResponsesTable", () => {

 // The truncated text should be present for both input and output
 const truncatedTexts = screen.getAllByText(
-longInput.slice(0, 10) + "...",
+longInput.slice(0, 10) + "..."
 );
 expect(truncatedTexts.length).toBe(2); // one for input, one for output
 });
@@ -27,7 +27,7 @@ interface ResponsesTableProps {
 * Helper function to convert ResponseListResponse.Data to OpenAIResponse
 */
 const convertResponseListData = (
-responseData: ResponseListResponse.Data,
+responseData: ResponseListResponse.Data
 ): OpenAIResponse => {
 return {
 id: responseData.id,

@@ -56,8 +56,8 @@ function getInputText(response: OpenAIResponse): string {
 }

 function getOutputText(response: OpenAIResponse): string {
-const firstMessage = response.output.find((item) =>
+const firstMessage = response.output.find(item =>
-isMessageItem(item as any),
+isMessageItem(item as Record<string, unknown>)
 );
 if (firstMessage) {
 const content = extractContentFromItem(firstMessage as MessageItem);

@@ -66,15 +66,15 @@ function getOutputText(response: OpenAIResponse): string {
 }
 }

-const functionCall = response.output.find((item) =>
+const functionCall = response.output.find(item =>
-isFunctionCallItem(item as any),
+isFunctionCallItem(item as Record<string, unknown>)
 );
 if (functionCall) {
 return formatFunctionCall(functionCall as FunctionCallItem);
 }

-const webSearchCall = response.output.find((item) =>
+const webSearchCall = response.output.find(item =>
-isWebSearchCallItem(item as any),
+isWebSearchCallItem(item as Record<string, unknown>)
 );
 if (webSearchCall) {
 return formatWebSearchCall(webSearchCall as WebSearchCallItem);

@@ -95,7 +95,7 @@ function extractContentFromItem(item: {
 } else if (Array.isArray(item.content)) {
 const textContent = item.content.find(
 (c: ResponseInputMessageContent) =>
-c.type === "input_text" || c.type === "output_text",
+c.type === "input_text" || c.type === "output_text"
 );
 return textContent?.text || "";
 }

@@ -131,14 +131,14 @@ export function ResponsesTable({ paginationOptions }: ResponsesTableProps) {
 limit: number;
 model?: string;
 order?: string;
-},
+}
 ) => {
 const response = await client.responses.list({
 after: params.after,
 limit: params.limit,
 ...(params.model && { model: params.model }),
 ...(params.order && { order: params.order }),
-} as any);
+} as Parameters<typeof client.responses.list>[0]);

 const listResponse = response as ResponseListResponse;
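The last hunk above replaces `as any` on the argument to `client.responses.list` with `Parameters<typeof client.responses.list>[0]`, which derives the expected argument type from the method's own signature instead of opting out of checking. A minimal, self-contained sketch of that utility-type pattern; `listThings` is a made-up stand-in, not the real client method:

```ts
// Stand-in SDK call, used only to illustrate Parameters<typeof fn>[0].
async function listThings(params: { after?: string; limit: number; order?: "asc" | "desc" }) {
  return { data: [] as string[], has_more: false, params };
}

// Derive the argument type from the function signature instead of writing `as any`.
type ListThingsParams = Parameters<typeof listThings>[0];

const query: ListThingsParams = { limit: 20, order: "desc" };
void listThings(query);
```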
@@ -29,7 +29,7 @@ export type AnyResponseItem =
 | FunctionCallOutputItem;

 export function isMessageInput(
-item: ResponseInput,
+item: ResponseInput
 ): item is ResponseInput & { type: "message" } {
 return item.type === "message";
 }

@@ -39,23 +39,23 @@ export function isMessageItem(item: AnyResponseItem): item is MessageItem {
 }

 export function isFunctionCallItem(
-item: AnyResponseItem,
+item: AnyResponseItem
 ): item is FunctionCallItem {
 return item.type === "function_call" && "name" in item;
 }

 export function isWebSearchCallItem(
-item: AnyResponseItem,
+item: AnyResponseItem
 ): item is WebSearchCallItem {
 return item.type === "web_search_call";
 }

 export function isFunctionCallOutputItem(
-item: AnyResponseItem,
+item: AnyResponseItem
 ): item is FunctionCallOutputItem {
 return (
 item.type === "function_call_output" &&
 "call_id" in item &&
-typeof (item as any).call_id === "string"
+typeof (item as Record<string, unknown>).call_id === "string"
 );
 }
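The functions in the hunks above are user-defined type guards: returning `item is SomeNarrowerType` lets a truthy result narrow the union at the call site, which is why the formatting-only changes keep the `item is ...` return annotations intact. A self-contained sketch of the same pattern on a simplified union; the type names here are illustrative, not the ones from the diff:

```ts
// Simplified stand-ins for the response item union; not the real types from this diff.
type FunctionCall = { type: "function_call"; name: string; arguments: string };
type WebSearchCall = { type: "web_search_call"; status: string };
type Item = FunctionCall | WebSearchCall;

// Returning `item is FunctionCall` tells the compiler what a truthy result means.
function isFunctionCall(item: Item): item is FunctionCall {
  return item.type === "function_call" && "name" in item;
}

function formatItem(item: Item): string {
  if (isFunctionCall(item)) {
    // Narrowed: `name` and `arguments` are available here without casts.
    return `${item.name}(${item.arguments})`;
  }
  return `web_search_call(status: ${item.status})`;
}
```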
@@ -1,6 +1,6 @@
-"use client"
+"use client";

-import { useEffect, useRef } from "react"
+import { useEffect, useRef } from "react";

 // Configuration constants for the audio analyzer
 const AUDIO_CONFIG = {

@@ -14,12 +14,12 @@ const AUDIO_CONFIG = {
 MAX_INTENSITY: 255, // Maximum gray value (brighter)
 INTENSITY_RANGE: 155, // MAX_INTENSITY - MIN_INTENSITY
 },
-} as const
+} as const;

 interface AudioVisualizerProps {
-stream: MediaStream | null
+stream: MediaStream | null;
-isRecording: boolean
+isRecording: boolean;
-onClick: () => void
+onClick: () => void;
 }

 export function AudioVisualizer({
@@ -28,91 +28,91 @@ export function AudioVisualizer({
 onClick,
 }: AudioVisualizerProps) {
 // Refs for managing audio context and animation
-const canvasRef = useRef<HTMLCanvasElement>(null)
+const canvasRef = useRef<HTMLCanvasElement>(null);
-const audioContextRef = useRef<AudioContext | null>(null)
+const audioContextRef = useRef<AudioContext | null>(null);
-const analyserRef = useRef<AnalyserNode | null>(null)
+const analyserRef = useRef<AnalyserNode | null>(null);
-const animationFrameRef = useRef<number>()
+const animationFrameRef = useRef<number>();
-const containerRef = useRef<HTMLDivElement>(null)
+const containerRef = useRef<HTMLDivElement>(null);

 // Cleanup function to stop visualization and close audio context
 const cleanup = () => {
 if (animationFrameRef.current) {
-cancelAnimationFrame(animationFrameRef.current)
+cancelAnimationFrame(animationFrameRef.current);
 }
 if (audioContextRef.current) {
-audioContextRef.current.close()
+audioContextRef.current.close();
 }
-}
+};

 // Cleanup on unmount
 useEffect(() => {
-return cleanup
+return cleanup;
-}, [])
+}, []);

 // Start or stop visualization based on recording state
 useEffect(() => {
 if (stream && isRecording) {
-startVisualization()
+startVisualization();
 } else {
-cleanup()
+cleanup();
 }
 // eslint-disable-next-line react-hooks/exhaustive-deps
-}, [stream, isRecording])
+}, [stream, isRecording]);

 // Handle window resize
 useEffect(() => {
 const handleResize = () => {
 if (canvasRef.current && containerRef.current) {
-const container = containerRef.current
+const container = containerRef.current;
-const canvas = canvasRef.current
+const canvas = canvasRef.current;
-const dpr = window.devicePixelRatio || 1
+const dpr = window.devicePixelRatio || 1;

 // Set canvas size based on container and device pixel ratio
-const rect = container.getBoundingClientRect()
+const rect = container.getBoundingClientRect();
 // Account for the 2px total margin (1px on each side)
-canvas.width = (rect.width - 2) * dpr
+canvas.width = (rect.width - 2) * dpr;
-canvas.height = (rect.height - 2) * dpr
+canvas.height = (rect.height - 2) * dpr;

 // Scale canvas CSS size to match container minus margins
-canvas.style.width = `${rect.width - 2}px`
+canvas.style.width = `${rect.width - 2}px`;
-canvas.style.height = `${rect.height - 2}px`
+canvas.style.height = `${rect.height - 2}px`;
 }
-}
+};

-window.addEventListener("resize", handleResize)
+window.addEventListener("resize", handleResize);
 // Initial setup
-handleResize()
+handleResize();

-return () => window.removeEventListener("resize", handleResize)
+return () => window.removeEventListener("resize", handleResize);
-}, [])
+}, []);

 // Initialize audio context and start visualization
 const startVisualization = async () => {
 try {
-const audioContext = new AudioContext()
+const audioContext = new AudioContext();
-audioContextRef.current = audioContext
+audioContextRef.current = audioContext;

-const analyser = audioContext.createAnalyser()
+const analyser = audioContext.createAnalyser();
-analyser.fftSize = AUDIO_CONFIG.FFT_SIZE
+analyser.fftSize = AUDIO_CONFIG.FFT_SIZE;
-analyser.smoothingTimeConstant = AUDIO_CONFIG.SMOOTHING
+analyser.smoothingTimeConstant = AUDIO_CONFIG.SMOOTHING;
-analyserRef.current = analyser
+analyserRef.current = analyser;

-const source = audioContext.createMediaStreamSource(stream!)
+const source = audioContext.createMediaStreamSource(stream!);
-source.connect(analyser)
+source.connect(analyser);

-draw()
+draw();
 } catch (error) {
-console.error("Error starting visualization:", error)
+console.error("Error starting visualization:", error);
 }
-}
+};

 // Calculate the color intensity based on bar height
 const getBarColor = (normalizedHeight: number) => {
 const intensity =
 Math.floor(normalizedHeight * AUDIO_CONFIG.COLOR.INTENSITY_RANGE) +
-AUDIO_CONFIG.COLOR.MIN_INTENSITY
+AUDIO_CONFIG.COLOR.MIN_INTENSITY;
-return `rgb(${intensity}, ${intensity}, ${intensity})`
+return `rgb(${intensity}, ${intensity}, ${intensity})`;
-}
+};

 // Draw a single bar of the visualizer
 const drawBar = (
@@ -123,52 +123,52 @@ export function AudioVisualizer({
 height: number,
 color: string
 ) => {
-ctx.fillStyle = color
+ctx.fillStyle = color;
 // Draw upper bar (above center)
-ctx.fillRect(x, centerY - height, width, height)
+ctx.fillRect(x, centerY - height, width, height);
 // Draw lower bar (below center)
-ctx.fillRect(x, centerY, width, height)
+ctx.fillRect(x, centerY, width, height);
-}
+};

 // Main drawing function
 const draw = () => {
-if (!isRecording) return
+if (!isRecording) return;

-const canvas = canvasRef.current
+const canvas = canvasRef.current;
-const ctx = canvas?.getContext("2d")
+const ctx = canvas?.getContext("2d");
-if (!canvas || !ctx || !analyserRef.current) return
+if (!canvas || !ctx || !analyserRef.current) return;

-const dpr = window.devicePixelRatio || 1
+const dpr = window.devicePixelRatio || 1;
-ctx.scale(dpr, dpr)
+ctx.scale(dpr, dpr);

-const analyser = analyserRef.current
+const analyser = analyserRef.current;
-const bufferLength = analyser.frequencyBinCount
+const bufferLength = analyser.frequencyBinCount;
-const frequencyData = new Uint8Array(bufferLength)
+const frequencyData = new Uint8Array(bufferLength);

 const drawFrame = () => {
-animationFrameRef.current = requestAnimationFrame(drawFrame)
+animationFrameRef.current = requestAnimationFrame(drawFrame);

 // Get current frequency data
-analyser.getByteFrequencyData(frequencyData)
+analyser.getByteFrequencyData(frequencyData);

 // Clear canvas - use CSS pixels for clearing
-ctx.clearRect(0, 0, canvas.width / dpr, canvas.height / dpr)
+ctx.clearRect(0, 0, canvas.width / dpr, canvas.height / dpr);

 // Calculate dimensions in CSS pixels
 const barWidth = Math.max(
 AUDIO_CONFIG.MIN_BAR_WIDTH,
 canvas.width / dpr / bufferLength - AUDIO_CONFIG.BAR_SPACING
-)
+);
-const centerY = canvas.height / dpr / 2
+const centerY = canvas.height / dpr / 2;
-let x = 0
+let x = 0;

 // Draw each frequency bar
 for (let i = 0; i < bufferLength; i++) {
-const normalizedHeight = frequencyData[i] / 255 // Convert to 0-1 range
+const normalizedHeight = frequencyData[i] / 255; // Convert to 0-1 range
 const barHeight = Math.max(
 AUDIO_CONFIG.MIN_BAR_HEIGHT,
 normalizedHeight * centerY
-)
+);

 drawBar(
 ctx,

@@ -177,14 +177,14 @@ export function AudioVisualizer({
 barWidth,
 barHeight,
 getBarColor(normalizedHeight)
-)
+);

-x += barWidth + AUDIO_CONFIG.BAR_SPACING
+x += barWidth + AUDIO_CONFIG.BAR_SPACING;
 }
-}
+};

-drawFrame()
+drawFrame();
-}
+};

 return (
 <div

@@ -194,5 +194,5 @@ export function AudioVisualizer({
 >
 <canvas ref={canvasRef} className="h-full w-full" />
 </div>
-)
+);
 }
@@ -14,7 +14,7 @@ function BreadcrumbList({ className, ...props }: React.ComponentProps<"ol">) {
 data-slot="breadcrumb-list"
 className={cn(
 "text-muted-foreground flex flex-wrap items-center gap-1.5 text-sm break-words sm:gap-2.5",
-className,
+className
 )}
 {...props}
 />
@@ -1,8 +1,8 @@
-import * as React from "react"
+import * as React from "react";
-import { Slot } from "@radix-ui/react-slot"
+import { Slot } from "@radix-ui/react-slot";
-import { cva, type VariantProps } from "class-variance-authority"
+import { cva, type VariantProps } from "class-variance-authority";

-import { cn } from "@/lib/utils"
+import { cn } from "@/lib/utils";

 const buttonVariants = cva(
 "inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-all disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 shrink-0 [&_svg]:shrink-0 outline-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive",

@@ -33,7 +33,7 @@ const buttonVariants = cva(
 size: "default",
 },
 }
-)
+);

 function Button({
 className,

@@ -43,9 +43,9 @@ function Button({
 ...props
 }: React.ComponentProps<"button"> &
 VariantProps<typeof buttonVariants> & {
-asChild?: boolean
+asChild?: boolean;
 }) {
-const Comp = asChild ? Slot : "button"
+const Comp = asChild ? Slot : "button";

 return (
 <Comp

@@ -53,7 +53,7 @@ function Button({
 className={cn(buttonVariants({ variant, size, className }))}
 {...props}
 />
-)
+);
 }

-export { Button, buttonVariants }
+export { Button, buttonVariants };
@@ -8,7 +8,7 @@ function Card({ className, ...props }: React.ComponentProps<"div">) {
 data-slot="card"
 className={cn(
 "bg-card text-card-foreground flex flex-col gap-6 rounded-xl border py-6 shadow-sm",
-className,
+className
 )}
 {...props}
 />

@@ -21,7 +21,7 @@ function CardHeader({ className, ...props }: React.ComponentProps<"div">) {
 data-slot="card-header"
 className={cn(
 "@container/card-header grid auto-rows-min grid-rows-[auto_auto] items-start gap-1.5 px-6 has-data-[slot=card-action]:grid-cols-[1fr_auto] [.border-b]:pb-6",
-className,
+className
 )}
 {...props}
 />

@@ -54,7 +54,7 @@ function CardAction({ className, ...props }: React.ComponentProps<"div">) {
 data-slot="card-action"
 className={cn(
 "col-start-2 row-span-2 row-start-1 self-start justify-self-end",
-className,
+className
 )}
 {...props}
 />
@@ -1,11 +1,11 @@
-"use client"
+"use client";

-import * as CollapsiblePrimitive from "@radix-ui/react-collapsible"
+import * as CollapsiblePrimitive from "@radix-ui/react-collapsible";

 function Collapsible({
 ...props
 }: React.ComponentProps<typeof CollapsiblePrimitive.Root>) {
-return <CollapsiblePrimitive.Root data-slot="collapsible" {...props} />
+return <CollapsiblePrimitive.Root data-slot="collapsible" {...props} />;
 }

 function CollapsibleTrigger({

@@ -16,7 +16,7 @@ function CollapsibleTrigger({
 data-slot="collapsible-trigger"
 {...props}
 />
-)
+);
 }

 function CollapsibleContent({

@@ -27,7 +27,7 @@ function CollapsibleContent({
 data-slot="collapsible-content"
 {...props}
 />
-)
+);
 }

-export { Collapsible, CollapsibleTrigger, CollapsibleContent }
+export { Collapsible, CollapsibleTrigger, CollapsibleContent };
@@ -1,21 +1,21 @@
-"use client"
+"use client";

-import { Check, Copy } from "lucide-react"
+import { Check, Copy } from "lucide-react";

-import { cn } from "@/lib/utils"
+import { cn } from "@/lib/utils";
-import { useCopyToClipboard } from "@/hooks/use-copy-to-clipboard"
+import { useCopyToClipboard } from "@/hooks/use-copy-to-clipboard";
-import { Button } from "@/components/ui/button"
+import { Button } from "@/components/ui/button";

 type CopyButtonProps = {
-content: string
+content: string;
-copyMessage?: string
+copyMessage?: string;
-}
+};

 export function CopyButton({ content, copyMessage }: CopyButtonProps) {
 const { isCopied, handleCopy } = useCopyToClipboard({
 text: content,
 copyMessage,
-})
+});

 return (
 <Button

@@ -40,5 +40,5 @@ export function CopyButton({ content, copyMessage }: CopyButtonProps) {
 )}
 />
 </Button>
-)
+);
 }
@@ -43,7 +43,7 @@ function DropdownMenuContent({
 sideOffset={sideOffset}
 className={cn(
 "bg-popover text-popover-foreground data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 z-50 max-h-(--radix-dropdown-menu-content-available-height) min-w-[8rem] origin-(--radix-dropdown-menu-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border p-1 shadow-md",
-className,
+className
 )}
 {...props}
 />

@@ -75,7 +75,7 @@ function DropdownMenuItem({
 data-variant={variant}
 className={cn(
 "focus:bg-accent focus:text-accent-foreground data-[variant=destructive]:text-destructive data-[variant=destructive]:focus:bg-destructive/10 dark:data-[variant=destructive]:focus:bg-destructive/20 data-[variant=destructive]:focus:text-destructive data-[variant=destructive]:*:[svg]:!text-destructive [&_svg:not([class*='text-'])]:text-muted-foreground relative flex cursor-default items-center gap-2 rounded-sm px-2 py-1.5 text-sm outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[inset]:pl-8 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4",
-className,
+className
 )}
 {...props}
 />

@@ -93,7 +93,7 @@ function DropdownMenuCheckboxItem({
 data-slot="dropdown-menu-checkbox-item"
 className={cn(
 "focus:bg-accent focus:text-accent-foreground relative flex cursor-default items-center gap-2 rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4",
-className,
+className
 )}
 checked={checked}
 {...props}

@@ -129,7 +129,7 @@ function DropdownMenuRadioItem({
 data-slot="dropdown-menu-radio-item"
 className={cn(
 "focus:bg-accent focus:text-accent-foreground relative flex cursor-default items-center gap-2 rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4",
-className,
+className
 )}
 {...props}
 >

@@ -156,7 +156,7 @@ function DropdownMenuLabel({
 data-inset={inset}
 className={cn(
 "px-2 py-1.5 text-sm font-medium data-[inset]:pl-8",
-className,
+className
 )}
 {...props}
 />

@@ -185,7 +185,7 @@ function DropdownMenuShortcut({
 data-slot="dropdown-menu-shortcut"
 className={cn(
 "text-muted-foreground ml-auto text-xs tracking-widest",
-className,
+className
 )}
 {...props}
 />

@@ -212,7 +212,7 @@ function DropdownMenuSubTrigger({
 data-inset={inset}
 className={cn(
 "focus:bg-accent focus:text-accent-foreground data-[state=open]:bg-accent data-[state=open]:text-accent-foreground flex cursor-default items-center rounded-sm px-2 py-1.5 text-sm outline-hidden select-none data-[inset]:pl-8",
-className,
+className
 )}
 {...props}
 >

@@ -231,7 +231,7 @@ function DropdownMenuSubContent({
 data-slot="dropdown-menu-sub-content"
 className={cn(
 "bg-popover text-popover-foreground data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 z-50 min-w-[8rem] origin-(--radix-dropdown-menu-content-transform-origin) overflow-hidden rounded-md border p-1 shadow-lg",
-className,
+className
 )}
 {...props}
 />
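Every component in these UI hunks funnels its classes through `cn(..., className)` so a caller-supplied `className` is merged in last. The helper itself is not part of this diff; in shadcn/ui-style codebases it is commonly a thin wrapper over `clsx` and `tailwind-merge`, roughly as sketched below (an assumption about this repo, not a line from the diff):

```ts
// lib/utils.ts-style helper — a common shadcn/ui pattern, assumed here, not shown in this diff.
import { clsx, type ClassValue } from "clsx";
import { twMerge } from "tailwind-merge";

export function cn(...inputs: ClassValue[]): string {
  // clsx flattens conditional/array inputs; twMerge resolves conflicting Tailwind
  // classes so the caller-supplied className can override component defaults.
  return twMerge(clsx(inputs));
}
```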