mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
Merge branch 'main' into chroma
This commit is contained in:
commit
c66ebae9b6
207 changed files with 15490 additions and 7927 deletions
20
.github/actions/run-and-record-tests/action.yml
vendored
20
.github/actions/run-and-record-tests/action.yml
vendored
|
@ -2,9 +2,13 @@ name: 'Run and Record Tests'
|
|||
description: 'Run integration tests and handle recording/artifact upload'
|
||||
|
||||
inputs:
|
||||
test-types:
|
||||
description: 'JSON array of test types to run'
|
||||
test-subdirs:
|
||||
description: 'Comma-separated list of test subdirectories to run'
|
||||
required: true
|
||||
test-pattern:
|
||||
description: 'Regex pattern to pass to pytest -k'
|
||||
required: false
|
||||
default: ''
|
||||
stack-config:
|
||||
description: 'Stack configuration to use'
|
||||
required: true
|
||||
|
@ -35,9 +39,11 @@ runs:
|
|||
./scripts/integration-tests.sh \
|
||||
--stack-config '${{ inputs.stack-config }}' \
|
||||
--provider '${{ inputs.provider }}' \
|
||||
--test-types '${{ inputs.test-types }}' \
|
||||
--test-subdirs '${{ inputs.test-subdirs }}' \
|
||||
--test-pattern '${{ inputs.test-pattern }}' \
|
||||
--inference-mode '${{ inputs.inference-mode }}' \
|
||||
${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }}
|
||||
${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} \
|
||||
| tee pytest-${{ inputs.inference-mode }}.log
|
||||
|
||||
|
||||
- name: Commit and push recordings
|
||||
|
@ -57,10 +63,10 @@ runs:
|
|||
git commit -m "Recordings update from CI"
|
||||
fi
|
||||
|
||||
git fetch origin ${{ github.event.pull_request.head.ref }}
|
||||
git rebase origin/${{ github.event.pull_request.head.ref }}
|
||||
git fetch origin ${{ github.ref_name }}
|
||||
git rebase origin/${{ github.ref_name }}
|
||||
echo "Rebased successfully"
|
||||
git push origin HEAD:${{ github.event.pull_request.head.ref }}
|
||||
git push origin HEAD:${{ github.ref_name }}
|
||||
echo "Pushed successfully"
|
||||
else
|
||||
echo "No recording changes"
|
||||
|
|
2
.github/actions/setup-runner/action.yml
vendored
2
.github/actions/setup-runner/action.yml
vendored
|
@ -28,7 +28,7 @@ runs:
|
|||
# Install llama-stack-client-python based on the client-version input
|
||||
if [ "${{ inputs.client-version }}" = "latest" ]; then
|
||||
echo "Installing latest llama-stack-client-python from main branch"
|
||||
uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
|
||||
uv pip install git+https://github.com/llamastack/llama-stack-client-python.git@main
|
||||
elif [ "${{ inputs.client-version }}" = "published" ]; then
|
||||
echo "Installing published llama-stack-client-python from PyPI"
|
||||
uv pip install llama-stack-client
|
||||
|
|
30
.github/workflows/integration-tests.yml
vendored
30
.github/workflows/integration-tests.yml
vendored
|
@ -31,6 +31,14 @@ on:
|
|||
description: 'Test against a specific provider'
|
||||
type: string
|
||||
default: 'ollama'
|
||||
test-subdirs:
|
||||
description: 'Comma-separated list of test subdirectories to run'
|
||||
type: string
|
||||
default: ''
|
||||
test-pattern:
|
||||
description: 'Regex pattern to pass to pytest -k'
|
||||
type: string
|
||||
default: ''
|
||||
|
||||
concurrency:
|
||||
# Skip concurrency for pushes to main - each commit should be tested independently
|
||||
|
@ -38,27 +46,8 @@ concurrency:
|
|||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
discover-tests:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
test-types: ${{ steps.generate-test-types.outputs.test-types }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Generate test types
|
||||
id: generate-test-types
|
||||
run: |
|
||||
# Get test directories dynamically, excluding non-test directories
|
||||
# NOTE: we are excluding post_training since the tests take too long
|
||||
TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
|
||||
grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
|
||||
sort | jq -R -s -c 'split("\n")[:-1]')
|
||||
echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
|
||||
|
||||
run-replay-mode-tests:
|
||||
needs: discover-tests
|
||||
runs-on: ubuntu-latest
|
||||
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }}
|
||||
|
||||
|
@ -89,7 +78,8 @@ jobs:
|
|||
- name: Run tests
|
||||
uses: ./.github/actions/run-and-record-tests
|
||||
with:
|
||||
test-types: ${{ needs.discover-tests.outputs.test-types }}
|
||||
test-subdirs: ${{ inputs.test-subdirs }}
|
||||
test-pattern: ${{ inputs.test-pattern }}
|
||||
stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
|
||||
provider: ${{ matrix.provider }}
|
||||
inference-mode: 'replay'
|
||||
|
|
|
@ -14,9 +14,11 @@ on:
|
|||
- 'pyproject.toml'
|
||||
- 'requirements.txt'
|
||||
- '.github/workflows/integration-vector-io-tests.yml' # This workflow
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # (test on python 3.13) Daily at 12 AM UTC
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
|
@ -25,7 +27,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"]
|
||||
python-version: ["3.12", "3.13"]
|
||||
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
|
||||
fail-fast: false # we want to run all tests regardless of failure
|
||||
|
||||
steps:
|
||||
|
@ -164,9 +166,9 @@ jobs:
|
|||
ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }}
|
||||
WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
|
||||
run: |
|
||||
uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
|
||||
uv run pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
|
||||
tests/integration/vector_io \
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
--embedding-model inline::sentence-transformers/all-MiniLM-L6-v2
|
||||
|
||||
- name: Check Storage and Memory Available After Tests
|
||||
if: ${{ always() }}
|
||||
|
|
101
.github/workflows/record-integration-tests.yml
vendored
101
.github/workflows/record-integration-tests.yml
vendored
|
@ -1,93 +1,53 @@
|
|||
# This workflow should be run manually when needing to re-record tests. This happens when you have
|
||||
# - added a new test
|
||||
# - or changed an existing test such that a new inference call is made
|
||||
# You should make a PR and then run this workflow on that PR branch. The workflow will re-record the
|
||||
# tests and commit the recordings to the PR branch.
|
||||
name: Integration Tests (Record)
|
||||
|
||||
run-name: Run the integration test suite from tests/integration
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
types: [opened, synchronize, labeled]
|
||||
paths:
|
||||
- 'llama_stack/**'
|
||||
- 'tests/**'
|
||||
- 'uv.lock'
|
||||
- 'pyproject.toml'
|
||||
- '.github/workflows/record-integration-tests.yml' # This workflow
|
||||
- '.github/actions/setup-ollama/action.yml'
|
||||
- '.github/actions/setup-test-environment/action.yml'
|
||||
- '.github/actions/run-and-record-tests/action.yml'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
test-subdirs:
|
||||
description: 'Comma-separated list of test subdirectories to run'
|
||||
type: string
|
||||
default: ''
|
||||
test-provider:
|
||||
description: 'Test against a specific provider'
|
||||
type: string
|
||||
default: 'ollama'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
run-vision-tests:
|
||||
description: 'Whether to run vision tests'
|
||||
type: boolean
|
||||
default: false
|
||||
test-pattern:
|
||||
description: 'Regex pattern to pass to pytest -k'
|
||||
type: string
|
||||
default: ''
|
||||
|
||||
jobs:
|
||||
discover-tests:
|
||||
if: contains(github.event.pull_request.labels.*.name, 're-record-tests') ||
|
||||
contains(github.event.pull_request.labels.*.name, 're-record-vision-tests')
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
test-types: ${{ steps.generate-test-types.outputs.test-types }}
|
||||
matrix-modes: ${{ steps.generate-test-types.outputs.matrix-modes }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Generate test types
|
||||
id: generate-test-types
|
||||
run: |
|
||||
# Get test directories dynamically, excluding non-test directories
|
||||
TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
|
||||
grep -Ev "^(__pycache__|fixtures|test_cases|recordings|post_training)$" |
|
||||
sort | jq -R -s -c 'split("\n")[:-1]')
|
||||
echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
|
||||
|
||||
labels=$(gh pr view ${{ github.event.pull_request.number }} --json labels --jq '.labels[].name')
|
||||
echo "labels=$labels"
|
||||
|
||||
modes_array=()
|
||||
if [[ $labels == *"re-record-vision-tests"* ]]; then
|
||||
modes_array+=("vision")
|
||||
fi
|
||||
if [[ $labels == *"re-record-tests"* ]]; then
|
||||
modes_array+=("non-vision")
|
||||
fi
|
||||
|
||||
# Convert to JSON array
|
||||
if [ ${#modes_array[@]} -eq 0 ]; then
|
||||
matrix_modes="[]"
|
||||
else
|
||||
matrix_modes=$(printf '%s\n' "${modes_array[@]}" | jq -R -s -c 'split("\n")[:-1]')
|
||||
fi
|
||||
echo "matrix_modes=$matrix_modes"
|
||||
echo "matrix-modes=$matrix_modes" >> $GITHUB_OUTPUT
|
||||
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
|
||||
record-tests:
|
||||
needs: discover-tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
mode: ${{ fromJSON(needs.discover-tests.outputs.matrix-modes) }}
|
||||
|
||||
steps:
|
||||
- name: Echo workflow inputs
|
||||
run: |
|
||||
echo "::group::Workflow Inputs"
|
||||
echo "test-subdirs: ${{ inputs.test-subdirs }}"
|
||||
echo "test-provider: ${{ inputs.test-provider }}"
|
||||
echo "run-vision-tests: ${{ inputs.run-vision-tests }}"
|
||||
echo "test-pattern: ${{ inputs.test-pattern }}"
|
||||
echo "branch: ${{ github.ref_name }}"
|
||||
echo "::endgroup::"
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.ref }}
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup test environment
|
||||
|
@ -96,14 +56,15 @@ jobs:
|
|||
python-version: "3.12" # Use single Python version for recording
|
||||
client-version: "latest"
|
||||
provider: ${{ inputs.test-provider || 'ollama' }}
|
||||
run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }}
|
||||
run-vision-tests: ${{ inputs.run-vision-tests }}
|
||||
inference-mode: 'record'
|
||||
|
||||
- name: Run and record tests
|
||||
uses: ./.github/actions/run-and-record-tests
|
||||
with:
|
||||
test-types: ${{ needs.discover-tests.outputs.test-types }}
|
||||
test-pattern: ${{ inputs.test-pattern }}
|
||||
test-subdirs: ${{ inputs.test-subdirs }}
|
||||
stack-config: 'server:ci-tests' # recording must be done with server since more tests are run
|
||||
provider: ${{ inputs.test-provider || 'ollama' }}
|
||||
inference-mode: 'record'
|
||||
run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }}
|
||||
run-vision-tests: ${{ inputs.run-vision-tests }}
|
||||
|
|
2
.github/workflows/semantic-pr.yml
vendored
2
.github/workflows/semantic-pr.yml
vendored
|
@ -11,7 +11,7 @@ on:
|
|||
- synchronize
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
|
|
|
@ -2,6 +2,7 @@ exclude: 'build/'
|
|||
|
||||
default_language_version:
|
||||
python: python3.12
|
||||
node: "22"
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
|
@ -145,6 +146,20 @@ repos:
|
|||
pass_filenames: false
|
||||
require_serial: true
|
||||
files: ^.github/workflows/.*$
|
||||
- id: ui-prettier
|
||||
name: Format UI code with Prettier
|
||||
entry: bash -c 'cd llama_stack/ui && npm run format'
|
||||
language: system
|
||||
files: ^llama_stack/ui/.*\.(ts|tsx)$
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
- id: ui-eslint
|
||||
name: Lint UI code with ESLint
|
||||
entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet'
|
||||
language: system
|
||||
files: ^llama_stack/ui/.*\.(ts|tsx)$
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
|
||||
ci:
|
||||
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
|
||||
|
|
169
CONTRIBUTING.md
169
CONTRIBUTING.md
|
@ -1,13 +1,82 @@
|
|||
# Contributing to Llama-Stack
|
||||
# Contributing to Llama Stack
|
||||
We want to make contributing to this project as easy and transparent as
|
||||
possible.
|
||||
|
||||
## Set up your development environment
|
||||
|
||||
We use [uv](https://github.com/astral-sh/uv) to manage python dependencies and virtual environments.
|
||||
You can install `uv` by following this [guide](https://docs.astral.sh/uv/getting-started/installation/).
|
||||
|
||||
You can install the dependencies by running:
|
||||
|
||||
```bash
|
||||
cd llama-stack
|
||||
uv sync --group dev
|
||||
uv pip install -e .
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
```{note}
|
||||
You can use a specific version of Python with `uv` by adding the `--python <version>` flag (e.g. `--python 3.12`).
|
||||
Otherwise, `uv` will automatically select a Python version according to the `requires-python` section of the `pyproject.toml`.
|
||||
For more info, see the [uv docs around Python versions](https://docs.astral.sh/uv/concepts/python-versions/).
|
||||
```
|
||||
|
||||
Note that you can create a dotenv file `.env` that includes necessary environment variables:
|
||||
```
|
||||
LLAMA_STACK_BASE_URL=http://localhost:8321
|
||||
LLAMA_STACK_CLIENT_LOG=debug
|
||||
LLAMA_STACK_PORT=8321
|
||||
LLAMA_STACK_CONFIG=<provider-name>
|
||||
TAVILY_SEARCH_API_KEY=
|
||||
BRAVE_SEARCH_API_KEY=
|
||||
```
|
||||
|
||||
And then use this dotenv file when running client SDK tests via the following:
|
||||
```bash
|
||||
uv run --env-file .env -- pytest -v tests/integration/inference/test_text_inference.py --text-model=meta-llama/Llama-3.1-8B-Instruct
|
||||
```
|
||||
|
||||
### Pre-commit Hooks
|
||||
|
||||
We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running:
|
||||
|
||||
```bash
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
After that, pre-commit hooks will run automatically before each commit.
|
||||
|
||||
Alternatively, if you don't want to install the pre-commit hooks, you can run the checks manually by running:
|
||||
|
||||
```bash
|
||||
uv run pre-commit run --all-files
|
||||
```
|
||||
|
||||
```{caution}
|
||||
Before pushing your changes, make sure that the pre-commit hooks have passed successfully.
|
||||
```
|
||||
|
||||
## Discussions -> Issues -> Pull Requests
|
||||
|
||||
We actively welcome your pull requests. However, please read the following. This is heavily inspired by [Ghostty](https://github.com/ghostty-org/ghostty/blob/main/CONTRIBUTING.md).
|
||||
|
||||
If in doubt, please open a [discussion](https://github.com/meta-llama/llama-stack/discussions); we can always convert that to an issue later.
|
||||
|
||||
### Issues
|
||||
We use GitHub issues to track public bugs. Please ensure your description is
|
||||
clear and has sufficient instructions to be able to reproduce the issue.
|
||||
|
||||
Meta has a [bounty program](http://facebook.com/whitehat/info) for the safe
|
||||
disclosure of security bugs. In those cases, please go through the process
|
||||
outlined on that page and do not file a public issue.
|
||||
|
||||
### Contributor License Agreement ("CLA")
|
||||
In order to accept your pull request, we need you to submit a CLA. You only need
|
||||
to do this once to work on any of Meta's open source projects.
|
||||
|
||||
Complete your CLA here: <https://code.facebook.com/cla>
|
||||
|
||||
**I'd like to contribute!**
|
||||
|
||||
If you are new to the project, start by looking at the issues tagged with "good first issue". If you're interested
|
||||
|
@ -51,93 +120,15 @@ Please avoid picking up too many issues at once. This helps you stay focused and
|
|||
|
||||
Please keep pull requests (PRs) small and focused. If you have a large set of changes, consider splitting them into logically grouped, smaller PRs to facilitate review and testing.
|
||||
|
||||
> [!TIP]
|
||||
> As a general guideline:
|
||||
> - Experienced contributors should try to keep no more than 5 open PRs at a time.
|
||||
> - New contributors are encouraged to have only one open PR at a time until they’re familiar with the codebase and process.
|
||||
|
||||
## Contributor License Agreement ("CLA")
|
||||
In order to accept your pull request, we need you to submit a CLA. You only need
|
||||
to do this once to work on any of Meta's open source projects.
|
||||
|
||||
Complete your CLA here: <https://code.facebook.com/cla>
|
||||
|
||||
## Issues
|
||||
We use GitHub issues to track public bugs. Please ensure your description is
|
||||
clear and has sufficient instructions to be able to reproduce the issue.
|
||||
|
||||
Meta has a [bounty program](http://facebook.com/whitehat/info) for the safe
|
||||
disclosure of security bugs. In those cases, please go through the process
|
||||
outlined on that page and do not file a public issue.
|
||||
|
||||
|
||||
## Set up your development environment
|
||||
|
||||
We use [uv](https://github.com/astral-sh/uv) to manage python dependencies and virtual environments.
|
||||
You can install `uv` by following this [guide](https://docs.astral.sh/uv/getting-started/installation/).
|
||||
|
||||
You can install the dependencies by running:
|
||||
|
||||
```bash
|
||||
cd llama-stack
|
||||
uv sync --group dev
|
||||
uv pip install -e .
|
||||
source .venv/bin/activate
|
||||
```{tip}
|
||||
As a general guideline:
|
||||
- Experienced contributors should try to keep no more than 5 open PRs at a time.
|
||||
- New contributors are encouraged to have only one open PR at a time until they’re familiar with the codebase and process.
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> You can use a specific version of Python with `uv` by adding the `--python <version>` flag (e.g. `--python 3.12`)
|
||||
> Otherwise, `uv` will automatically select a Python version according to the `requires-python` section of the `pyproject.toml`.
|
||||
> For more info, see the [uv docs around Python versions](https://docs.astral.sh/uv/concepts/python-versions/).
|
||||
## Repository guidelines
|
||||
|
||||
Note that you can create a dotenv file `.env` that includes necessary environment variables:
|
||||
```
|
||||
LLAMA_STACK_BASE_URL=http://localhost:8321
|
||||
LLAMA_STACK_CLIENT_LOG=debug
|
||||
LLAMA_STACK_PORT=8321
|
||||
LLAMA_STACK_CONFIG=<provider-name>
|
||||
TAVILY_SEARCH_API_KEY=
|
||||
BRAVE_SEARCH_API_KEY=
|
||||
```
|
||||
|
||||
And then use this dotenv file when running client SDK tests via the following:
|
||||
```bash
|
||||
uv run --env-file .env -- pytest -v tests/integration/inference/test_text_inference.py --text-model=meta-llama/Llama-3.1-8B-Instruct
|
||||
```
|
||||
|
||||
## Pre-commit Hooks
|
||||
|
||||
We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running:
|
||||
|
||||
```bash
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
After that, pre-commit hooks will run automatically before each commit.
|
||||
|
||||
Alternatively, if you don't want to install the pre-commit hooks, you can run the checks manually by running:
|
||||
|
||||
```bash
|
||||
uv run pre-commit run --all-files
|
||||
```
|
||||
|
||||
> [!CAUTION]
|
||||
> Before pushing your changes, make sure that the pre-commit hooks have passed successfully.
|
||||
|
||||
## Running tests
|
||||
|
||||
You can find the Llama Stack testing documentation [here](https://github.com/meta-llama/llama-stack/blob/main/tests/README.md).
|
||||
|
||||
## Adding a new dependency to the project
|
||||
|
||||
To add a new dependency to the project, you can use the `uv` command. For example, to add `foo` to the project, you can run:
|
||||
|
||||
```bash
|
||||
uv add foo
|
||||
uv sync
|
||||
```
|
||||
|
||||
## Coding Style
|
||||
### Coding Style
|
||||
|
||||
* Comments should provide meaningful insights into the code. Avoid filler comments that simply
|
||||
describe the next step, as they create unnecessary clutter, same goes for docstrings.
|
||||
|
@ -159,6 +150,10 @@ uv sync
|
|||
* When possible, use keyword arguments only when calling functions.
|
||||
* Llama Stack utilizes [custom Exception classes](llama_stack/apis/common/errors.py) for certain Resources that should be used where applicable.
|
||||
|
||||
### License
|
||||
By contributing to Llama, you agree that your contributions will be licensed
|
||||
under the LICENSE file in the root directory of this source tree.
|
||||
|
||||
## Common Tasks
|
||||
|
||||
Some tips about common tasks you work on while contributing to Llama Stack:
|
||||
|
@ -211,7 +206,3 @@ uv run ./docs/openapi_generator/run_openapi_generator.sh
|
|||
```
|
||||
|
||||
The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.
|
||||
|
||||
## License
|
||||
By contributing to Llama, you agree that your contributions will be licensed
|
||||
under the LICENSE file in the root directory of this source tree.
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
# Llama Stack
|
||||
|
||||
<a href="https://trendshift.io/repositories/11824" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11824" alt="meta-llama%2Fllama-stack | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
-----
|
||||
[](https://pypi.org/project/llama_stack/)
|
||||
[](https://pypi.org/project/llama-stack/)
|
||||
[](https://github.com/meta-llama/llama-stack/blob/main/LICENSE)
|
||||
|
|
226
docs/_static/llama-stack-spec.html
vendored
226
docs/_static/llama-stack-spec.html
vendored
|
@ -8293,28 +8293,60 @@
|
|||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "null"
|
||||
"properties": {
|
||||
"attributes": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "null"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
}
|
||||
"description": "(Optional) Key-value attributes associated with the file"
|
||||
},
|
||||
"file_id": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier of the file containing the result"
|
||||
},
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "Name of the file containing the result"
|
||||
},
|
||||
"score": {
|
||||
"type": "number",
|
||||
"description": "Relevance score for this search result (between 0 and 1)"
|
||||
},
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text content of the search result"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"attributes",
|
||||
"file_id",
|
||||
"filename",
|
||||
"score",
|
||||
"text"
|
||||
],
|
||||
"title": "OpenAIResponseOutputMessageFileSearchToolCallResults",
|
||||
"description": "Search results returned by the file search operation."
|
||||
},
|
||||
"description": "(Optional) Search results returned by the file search operation"
|
||||
}
|
||||
|
@ -8515,6 +8547,13 @@
|
|||
"$ref": "#/components/schemas/OpenAIResponseInputTool"
|
||||
}
|
||||
},
|
||||
"include": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "(Optional) Additional fields to include in the response."
|
||||
},
|
||||
"max_infer_iters": {
|
||||
"type": "integer"
|
||||
}
|
||||
|
@ -8782,6 +8821,61 @@
|
|||
"title": "OpenAIResponseOutputMessageMCPListTools",
|
||||
"description": "MCP list tools output message containing available tools from an MCP server."
|
||||
},
|
||||
"OpenAIResponseContentPart": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseContentPartOutputText"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseContentPartRefusal"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"output_text": "#/components/schemas/OpenAIResponseContentPartOutputText",
|
||||
"refusal": "#/components/schemas/OpenAIResponseContentPartRefusal"
|
||||
}
|
||||
}
|
||||
},
|
||||
"OpenAIResponseContentPartOutputText": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "output_text",
|
||||
"default": "output_text"
|
||||
},
|
||||
"text": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"text"
|
||||
],
|
||||
"title": "OpenAIResponseContentPartOutputText"
|
||||
},
|
||||
"OpenAIResponseContentPartRefusal": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "refusal",
|
||||
"default": "refusal"
|
||||
},
|
||||
"refusal": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"refusal"
|
||||
],
|
||||
"title": "OpenAIResponseContentPartRefusal"
|
||||
},
|
||||
"OpenAIResponseObjectStream": {
|
||||
"oneOf": [
|
||||
{
|
||||
|
@ -8838,6 +8932,12 @@
|
|||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
|
||||
}
|
||||
|
@ -8863,6 +8963,8 @@
|
|||
"response.mcp_call.in_progress": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress",
|
||||
"response.mcp_call.failed": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed",
|
||||
"response.mcp_call.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted",
|
||||
"response.content_part.added": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded",
|
||||
"response.content_part.done": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone",
|
||||
"response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
|
||||
}
|
||||
}
|
||||
|
@ -8889,6 +8991,80 @@
|
|||
"title": "OpenAIResponseObjectStreamResponseCompleted",
|
||||
"description": "Streaming event indicating a response has been completed."
|
||||
},
|
||||
"OpenAIResponseObjectStreamResponseContentPartAdded": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"response_id": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier of the response containing this content"
|
||||
},
|
||||
"item_id": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier of the output item containing this content part"
|
||||
},
|
||||
"part": {
|
||||
"$ref": "#/components/schemas/OpenAIResponseContentPart",
|
||||
"description": "The content part that was added"
|
||||
},
|
||||
"sequence_number": {
|
||||
"type": "integer",
|
||||
"description": "Sequential number for ordering streaming events"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "response.content_part.added",
|
||||
"default": "response.content_part.added",
|
||||
"description": "Event type identifier, always \"response.content_part.added\""
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"response_id",
|
||||
"item_id",
|
||||
"part",
|
||||
"sequence_number",
|
||||
"type"
|
||||
],
|
||||
"title": "OpenAIResponseObjectStreamResponseContentPartAdded",
|
||||
"description": "Streaming event for when a new content part is added to a response item."
|
||||
},
|
||||
"OpenAIResponseObjectStreamResponseContentPartDone": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"response_id": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier of the response containing this content"
|
||||
},
|
||||
"item_id": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier of the output item containing this content part"
|
||||
},
|
||||
"part": {
|
||||
"$ref": "#/components/schemas/OpenAIResponseContentPart",
|
||||
"description": "The completed content part"
|
||||
},
|
||||
"sequence_number": {
|
||||
"type": "integer",
|
||||
"description": "Sequential number for ordering streaming events"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "response.content_part.done",
|
||||
"default": "response.content_part.done",
|
||||
"description": "Event type identifier, always \"response.content_part.done\""
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"response_id",
|
||||
"item_id",
|
||||
"part",
|
||||
"sequence_number",
|
||||
"type"
|
||||
],
|
||||
"title": "OpenAIResponseObjectStreamResponseContentPartDone",
|
||||
"description": "Streaming event for when a content part is completed."
|
||||
},
|
||||
"OpenAIResponseObjectStreamResponseCreated": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -14591,7 +14767,8 @@
|
|||
"OpenAIFilePurpose": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"assistants"
|
||||
"assistants",
|
||||
"batch"
|
||||
],
|
||||
"title": "OpenAIFilePurpose",
|
||||
"description": "Valid purpose values for OpenAI Files API."
|
||||
|
@ -14668,7 +14845,8 @@
|
|||
"purpose": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"assistants"
|
||||
"assistants",
|
||||
"batch"
|
||||
],
|
||||
"description": "The intended purpose of the file"
|
||||
}
|
||||
|
@ -16530,7 +16708,7 @@
|
|||
"additionalProperties": {
|
||||
"type": "number"
|
||||
},
|
||||
"description": "A list of the categories along with their scores as predicted by model. Required set of categories that need to be in response - violence - violence/graphic - harassment - harassment/threatening - hate - hate/threatening - illicit - illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent - self-harm/instructions"
|
||||
"description": "A list of the categories along with their scores as predicted by model."
|
||||
},
|
||||
"user_message": {
|
||||
"type": "string"
|
||||
|
|
169
docs/_static/llama-stack-spec.yaml
vendored
169
docs/_static/llama-stack-spec.yaml
vendored
|
@ -6021,14 +6021,44 @@ components:
|
|||
type: array
|
||||
items:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
properties:
|
||||
attributes:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
(Optional) Key-value attributes associated with the file
|
||||
file_id:
|
||||
type: string
|
||||
description: >-
|
||||
Unique identifier of the file containing the result
|
||||
filename:
|
||||
type: string
|
||||
description: Name of the file containing the result
|
||||
score:
|
||||
type: number
|
||||
description: >-
|
||||
Relevance score for this search result (between 0 and 1)
|
||||
text:
|
||||
type: string
|
||||
description: Text content of the search result
|
||||
additionalProperties: false
|
||||
required:
|
||||
- attributes
|
||||
- file_id
|
||||
- filename
|
||||
- score
|
||||
- text
|
||||
title: >-
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults
|
||||
description: >-
|
||||
Search results returned by the file search operation.
|
||||
description: >-
|
||||
(Optional) Search results returned by the file search operation
|
||||
additionalProperties: false
|
||||
|
@ -6188,6 +6218,12 @@ components:
|
|||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAIResponseInputTool'
|
||||
include:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
(Optional) Additional fields to include in the response.
|
||||
max_infer_iters:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
|
@ -6405,6 +6441,43 @@ components:
|
|||
title: OpenAIResponseOutputMessageMCPListTools
|
||||
description: >-
|
||||
MCP list tools output message containing available tools from an MCP server.
|
||||
OpenAIResponseContentPart:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseContentPartOutputText'
|
||||
- $ref: '#/components/schemas/OpenAIResponseContentPartRefusal'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
output_text: '#/components/schemas/OpenAIResponseContentPartOutputText'
|
||||
refusal: '#/components/schemas/OpenAIResponseContentPartRefusal'
|
||||
OpenAIResponseContentPartOutputText:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: output_text
|
||||
default: output_text
|
||||
text:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- text
|
||||
title: OpenAIResponseContentPartOutputText
|
||||
OpenAIResponseContentPartRefusal:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: refusal
|
||||
default: refusal
|
||||
refusal:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- refusal
|
||||
title: OpenAIResponseContentPartRefusal
|
||||
OpenAIResponseObjectStream:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
|
||||
|
@ -6425,6 +6498,8 @@ components:
|
|||
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress'
|
||||
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed'
|
||||
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted'
|
||||
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded'
|
||||
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone'
|
||||
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
|
@ -6447,6 +6522,8 @@ components:
|
|||
response.mcp_call.in_progress: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress'
|
||||
response.mcp_call.failed: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed'
|
||||
response.mcp_call.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted'
|
||||
response.content_part.added: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded'
|
||||
response.content_part.done: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone'
|
||||
response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
|
||||
"OpenAIResponseObjectStreamResponseCompleted":
|
||||
type: object
|
||||
|
@ -6468,6 +6545,76 @@ components:
|
|||
OpenAIResponseObjectStreamResponseCompleted
|
||||
description: >-
|
||||
Streaming event indicating a response has been completed.
|
||||
"OpenAIResponseObjectStreamResponseContentPartAdded":
|
||||
type: object
|
||||
properties:
|
||||
response_id:
|
||||
type: string
|
||||
description: >-
|
||||
Unique identifier of the response containing this content
|
||||
item_id:
|
||||
type: string
|
||||
description: >-
|
||||
Unique identifier of the output item containing this content part
|
||||
part:
|
||||
$ref: '#/components/schemas/OpenAIResponseContentPart'
|
||||
description: The content part that was added
|
||||
sequence_number:
|
||||
type: integer
|
||||
description: >-
|
||||
Sequential number for ordering streaming events
|
||||
type:
|
||||
type: string
|
||||
const: response.content_part.added
|
||||
default: response.content_part.added
|
||||
description: >-
|
||||
Event type identifier, always "response.content_part.added"
|
||||
additionalProperties: false
|
||||
required:
|
||||
- response_id
|
||||
- item_id
|
||||
- part
|
||||
- sequence_number
|
||||
- type
|
||||
title: >-
|
||||
OpenAIResponseObjectStreamResponseContentPartAdded
|
||||
description: >-
|
||||
Streaming event for when a new content part is added to a response item.
|
||||
"OpenAIResponseObjectStreamResponseContentPartDone":
|
||||
type: object
|
||||
properties:
|
||||
response_id:
|
||||
type: string
|
||||
description: >-
|
||||
Unique identifier of the response containing this content
|
||||
item_id:
|
||||
type: string
|
||||
description: >-
|
||||
Unique identifier of the output item containing this content part
|
||||
part:
|
||||
$ref: '#/components/schemas/OpenAIResponseContentPart'
|
||||
description: The completed content part
|
||||
sequence_number:
|
||||
type: integer
|
||||
description: >-
|
||||
Sequential number for ordering streaming events
|
||||
type:
|
||||
type: string
|
||||
const: response.content_part.done
|
||||
default: response.content_part.done
|
||||
description: >-
|
||||
Event type identifier, always "response.content_part.done"
|
||||
additionalProperties: false
|
||||
required:
|
||||
- response_id
|
||||
- item_id
|
||||
- part
|
||||
- sequence_number
|
||||
- type
|
||||
title: >-
|
||||
OpenAIResponseObjectStreamResponseContentPartDone
|
||||
description: >-
|
||||
Streaming event for when a content part is completed.
|
||||
"OpenAIResponseObjectStreamResponseCreated":
|
||||
type: object
|
||||
properties:
|
||||
|
@ -10804,6 +10951,7 @@ components:
|
|||
type: string
|
||||
enum:
|
||||
- assistants
|
||||
- batch
|
||||
title: OpenAIFilePurpose
|
||||
description: >-
|
||||
Valid purpose values for OpenAI Files API.
|
||||
|
@ -10872,6 +11020,7 @@ components:
|
|||
type: string
|
||||
enum:
|
||||
- assistants
|
||||
- batch
|
||||
description: The intended purpose of the file
|
||||
additionalProperties: false
|
||||
required:
|
||||
|
@ -12286,10 +12435,6 @@ components:
|
|||
type: number
|
||||
description: >-
|
||||
A list of the categories along with their scores as predicted by model.
|
||||
Required set of categories that need to be in response - violence - violence/graphic
|
||||
- harassment - harassment/threatening - hate - hate/threatening - illicit
|
||||
- illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent
|
||||
- self-harm/instructions
|
||||
user_message:
|
||||
type: string
|
||||
metadata:
|
||||
|
|
|
@ -111,7 +111,7 @@ name = "llama-stack-api-weather"
|
|||
version = "0.1.0"
|
||||
description = "Weather API for Llama Stack"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = ["llama-stack", "pydantic"]
|
||||
|
||||
[build-system]
|
||||
|
@ -231,7 +231,7 @@ name = "llama-stack-provider-kaze"
|
|||
version = "0.1.0"
|
||||
description = "Kaze weather provider for Llama Stack"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = ["llama-stack", "pydantic", "aiohttp"]
|
||||
|
||||
[build-system]
|
||||
|
|
|
@ -2,7 +2,9 @@
|
|||
|
||||
Llama Stack (LLS) provides two different APIs for building AI applications with tool calling capabilities: the **Agents API** and the **OpenAI Responses API**. While both enable AI systems to use tools, and maintain full conversation history, they serve different use cases and have distinct characteristics.
|
||||
|
||||
> **Note:** For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to Agents or Responses API.
|
||||
```{note}
|
||||
For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to Agents or Responses API.
|
||||
```
|
||||
|
||||
## Overview
|
||||
|
||||
|
|
|
@ -76,7 +76,9 @@ Features:
|
|||
- Context retrieval with token limits
|
||||
|
||||
|
||||
> **Note:** By default, llama stack run.yaml defines toolgroups for web search, wolfram alpha and rag, that are provided by tavily-search, wolfram-alpha and rag providers.
|
||||
```{note}
|
||||
By default, llama stack run.yaml defines toolgroups for web search, wolfram alpha and rag, that are provided by tavily-search, wolfram-alpha and rag providers.
|
||||
```
|
||||
|
||||
## Model Context Protocol (MCP)
|
||||
|
||||
|
|
|
@ -18,3 +18,4 @@ We are working on adding a few more APIs to complete the application lifecycle.
|
|||
- **Batch Inference**: run inference on a dataset of inputs
|
||||
- **Batch Agents**: run agents on a dataset of inputs
|
||||
- **Synthetic Data Generation**: generate synthetic data for model development
|
||||
- **Batches**: OpenAI-compatible batch management for inference
|
||||
|
|
|
@ -2,24 +2,13 @@
|
|||
```{include} ../../../CONTRIBUTING.md
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
See the [Test Page](testing.md) which describes how to test your changes.
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
:hidden:
|
||||
:caption: Testing
|
||||
|
||||
testing
|
||||
```
|
||||
|
||||
## Adding a New Provider
|
||||
|
||||
See the [Adding a New API Provider Page](new_api_provider.md) which describes how to add new API providers to the Stack.
|
||||
See:
|
||||
- [Adding a New API Provider Page](new_api_provider.md) which describes how to add new API providers to the Stack.
|
||||
- [Vector Database Page](new_vector_database.md) which describes how to add a new vector databases with Llama Stack.
|
||||
- [External Provider Page](../providers/external/index.md) which describes how to add external providers to the Stack.
|
||||
|
||||
See the [Vector Database Page](new_vector_database.md) which describes how to add a new vector databases with Llama Stack.
|
||||
|
||||
See the [External Provider Page](../providers/external/index.md) which describes how to add external providers to the Stack.
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
:hidden:
|
||||
|
@ -27,3 +16,24 @@ See the [External Provider Page](../providers/external/index.md) which describes
|
|||
new_api_provider
|
||||
new_vector_database
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
|
||||
```{include} ../../../tests/README.md
|
||||
```
|
||||
|
||||
## Benchmarking
|
||||
|
||||
```{include} ../../../docs/source/distributions/k8s-benchmark/README.md
|
||||
```
|
||||
|
||||
### Advanced Topics
|
||||
|
||||
For developers who need deeper understanding of the testing system internals:
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
testing/record-replay
|
||||
```
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
```{include} ../../../tests/README.md
|
||||
```
|
||||
|
||||
```{include} ../../../tests/unit/README.md
|
||||
```
|
||||
|
||||
```{include} ../../../tests/integration/README.md
|
||||
```
|
234
docs/source/contributing/testing/record-replay.md
Normal file
234
docs/source/contributing/testing/record-replay.md
Normal file
|
@ -0,0 +1,234 @@
|
|||
# Record-Replay System
|
||||
|
||||
Understanding how Llama Stack captures and replays API interactions for testing.
|
||||
|
||||
## Overview
|
||||
|
||||
The record-replay system solves a fundamental challenge in AI testing: how do you test against expensive, non-deterministic APIs without breaking the bank or dealing with flaky tests?
|
||||
|
||||
The solution: intercept API calls, store real responses, and replay them later. This gives you real API behavior without the cost or variability.
|
||||
|
||||
## How It Works
|
||||
|
||||
### Request Hashing
|
||||
|
||||
Every API request gets converted to a deterministic hash for lookup:
|
||||
|
||||
```python
|
||||
def normalize_request(method: str, url: str, headers: dict, body: dict) -> str:
|
||||
normalized = {
|
||||
"method": method.upper(),
|
||||
"endpoint": urlparse(url).path, # Just the path, not full URL
|
||||
"body": body, # Request parameters
|
||||
}
|
||||
return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()
|
||||
```
|
||||
|
||||
**Key insight:** The hashing is intentionally precise. Different whitespace, float precision, or parameter order produces different hashes. This prevents subtle bugs from false cache hits.
|
||||
|
||||
```python
|
||||
# These produce DIFFERENT hashes:
|
||||
{"content": "Hello world"}
|
||||
{"content": "Hello world\n"}
|
||||
{"temperature": 0.7}
|
||||
{"temperature": 0.7000001}
|
||||
```
|
||||
|
||||
### Client Interception
|
||||
|
||||
The system patches OpenAI and Ollama client methods to intercept calls before they leave your application. This happens transparently - your test code doesn't change.
|
||||
|
||||
### Storage Architecture
|
||||
|
||||
Recordings use a two-tier storage system optimized for both speed and debuggability:
|
||||
|
||||
```
|
||||
recordings/
|
||||
├── index.sqlite # Fast lookup by request hash
|
||||
└── responses/
|
||||
├── abc123def456.json # Individual response files
|
||||
└── def789ghi012.json
|
||||
```
|
||||
|
||||
**SQLite index** enables O(log n) hash lookups and metadata queries without loading response bodies.
|
||||
|
||||
**JSON files** store complete request/response pairs in human-readable format for debugging.
|
||||
|
||||
## Recording Modes
|
||||
|
||||
### LIVE Mode
|
||||
|
||||
Direct API calls with no recording or replay:
|
||||
|
||||
```python
|
||||
with inference_recording(mode=InferenceMode.LIVE):
|
||||
response = await client.chat.completions.create(...)
|
||||
```
|
||||
|
||||
Use for initial development and debugging against real APIs.
|
||||
|
||||
### RECORD Mode
|
||||
|
||||
Captures API interactions while passing through real responses:
|
||||
|
||||
```python
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir="./recordings"):
|
||||
response = await client.chat.completions.create(...)
|
||||
# Real API call made, response captured AND returned
|
||||
```
|
||||
|
||||
The recording process:
|
||||
1. Request intercepted and hashed
|
||||
2. Real API call executed
|
||||
3. Response captured and serialized
|
||||
4. Recording stored to disk
|
||||
5. Original response returned to caller
|
||||
|
||||
### REPLAY Mode
|
||||
|
||||
Returns stored responses instead of making API calls:
|
||||
|
||||
```python
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir="./recordings"):
|
||||
response = await client.chat.completions.create(...)
|
||||
# No API call made, cached response returned instantly
|
||||
```
|
||||
|
||||
The replay process:
|
||||
1. Request intercepted and hashed
|
||||
2. Hash looked up in SQLite index
|
||||
3. Response loaded from JSON file
|
||||
4. Response deserialized and returned
|
||||
5. Error if no recording found
|
||||
|
||||
## Streaming Support
|
||||
|
||||
Streaming APIs present a unique challenge: how do you capture an async generator?
|
||||
|
||||
### The Problem
|
||||
|
||||
```python
|
||||
# How do you record this?
|
||||
async for chunk in client.chat.completions.create(stream=True):
|
||||
process(chunk)
|
||||
```
|
||||
|
||||
### The Solution
|
||||
|
||||
The system captures all chunks immediately before yielding any:
|
||||
|
||||
```python
|
||||
async def handle_streaming_record(response):
|
||||
# Capture complete stream first
|
||||
chunks = []
|
||||
async for chunk in response:
|
||||
chunks.append(chunk)
|
||||
|
||||
# Store complete recording
|
||||
storage.store_recording(
|
||||
request_hash, request_data, {"body": chunks, "is_streaming": True}
|
||||
)
|
||||
|
||||
# Return generator that replays captured chunks
|
||||
async def replay_stream():
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
return replay_stream()
|
||||
```
|
||||
|
||||
This ensures:
|
||||
- **Complete capture** - The entire stream is saved atomically
|
||||
- **Interface preservation** - The returned object behaves like the original API
|
||||
- **Deterministic replay** - Same chunks in the same order every time
|
||||
|
||||
## Serialization
|
||||
|
||||
API responses contain complex Pydantic objects that need careful serialization:
|
||||
|
||||
```python
|
||||
def _serialize_response(response):
|
||||
if hasattr(response, "model_dump"):
|
||||
# Preserve type information for proper deserialization
|
||||
return {
|
||||
"__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}",
|
||||
"__data__": response.model_dump(mode="json"),
|
||||
}
|
||||
return response
|
||||
```
|
||||
|
||||
This preserves type safety - when replayed, you get the same Pydantic objects with all their validation and methods.
|
||||
|
||||
## Environment Integration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
Control recording behavior globally:
|
||||
|
||||
```bash
|
||||
export LLAMA_STACK_TEST_INFERENCE_MODE=replay
|
||||
export LLAMA_STACK_TEST_RECORDING_DIR=/path/to/recordings
|
||||
pytest tests/integration/
|
||||
```
|
||||
|
||||
### Pytest Integration
|
||||
|
||||
The system integrates automatically based on environment variables, requiring no changes to test code.
|
||||
|
||||
## Debugging Recordings
|
||||
|
||||
### Inspecting Storage
|
||||
|
||||
```bash
|
||||
# See what's recorded
|
||||
sqlite3 recordings/index.sqlite "SELECT endpoint, model, timestamp FROM recordings LIMIT 10;"
|
||||
|
||||
# View specific response
|
||||
cat recordings/responses/abc123def456.json | jq '.response.body'
|
||||
|
||||
# Find recordings by endpoint
|
||||
sqlite3 recordings/index.sqlite "SELECT * FROM recordings WHERE endpoint='/v1/chat/completions';"
|
||||
```
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Hash mismatches:** Request parameters changed slightly between record and replay
|
||||
```bash
|
||||
# Compare request details
|
||||
cat recordings/responses/abc123.json | jq '.request'
|
||||
```
|
||||
|
||||
**Serialization errors:** Response types changed between versions
|
||||
```bash
|
||||
# Re-record with updated types
|
||||
rm recordings/responses/failing_hash.json
|
||||
LLAMA_STACK_TEST_INFERENCE_MODE=record pytest test_failing.py
|
||||
```
|
||||
|
||||
**Missing recordings:** New test or changed parameters
|
||||
```bash
|
||||
# Record the missing interaction
|
||||
LLAMA_STACK_TEST_INFERENCE_MODE=record pytest test_new.py
|
||||
```
|
||||
|
||||
## Design Decisions
|
||||
|
||||
### Why Not Mocks?
|
||||
|
||||
Traditional mocking breaks down with AI APIs because:
|
||||
- Response structures are complex and evolve frequently
|
||||
- Streaming behavior is hard to mock correctly
|
||||
- Edge cases in real APIs get missed
|
||||
- Mocks become brittle maintenance burdens
|
||||
|
||||
### Why Precise Hashing?
|
||||
|
||||
Loose hashing (normalizing whitespace, rounding floats) seems convenient but hides bugs. If a test changes slightly, you want to know about it rather than accidentally getting the wrong cached response.
|
||||
|
||||
### Why JSON + SQLite?
|
||||
|
||||
- **JSON** - Human readable, diff-friendly, easy to inspect and modify
|
||||
- **SQLite** - Fast indexed lookups without loading response bodies
|
||||
- **Hybrid** - Best of both worlds for different use cases
|
||||
|
||||
This system provides reliable, fast testing against real AI APIs while maintaining the ability to debug issues when they arise.
|
156
docs/source/distributions/k8s-benchmark/README.md
Normal file
156
docs/source/distributions/k8s-benchmark/README.md
Normal file
|
@ -0,0 +1,156 @@
|
|||
# Llama Stack Benchmark Suite on Kubernetes
|
||||
|
||||
## Motivation
|
||||
|
||||
Performance benchmarking is critical for understanding the overhead and characteristics of the Llama Stack abstraction layer compared to direct inference engines like vLLM.
|
||||
|
||||
### Why This Benchmark Suite Exists
|
||||
|
||||
**Performance Validation**: The Llama Stack provides a unified API layer across multiple inference providers, but this abstraction introduces potential overhead. This benchmark suite quantifies the performance impact by comparing:
|
||||
- Llama Stack inference (with vLLM backend)
|
||||
- Direct vLLM inference calls
|
||||
- Both under identical Kubernetes deployment conditions
|
||||
|
||||
**Production Readiness Assessment**: Real-world deployments require understanding performance characteristics under load. This suite simulates concurrent user scenarios with configurable parameters (duration, concurrency, request patterns) to validate production readiness.
|
||||
|
||||
**Regression Detection (TODO)**: As the Llama Stack evolves, this benchmark provides automated regression detection for performance changes. CI/CD pipelines can leverage these benchmarks to catch performance degradations before production deployments.
|
||||
|
||||
**Resource Planning**: By measuring throughput, latency percentiles, and resource utilization patterns, teams can make informed decisions about:
|
||||
- Kubernetes resource allocation (CPU, memory, GPU)
|
||||
- Auto-scaling configurations
|
||||
- Cost optimization strategies
|
||||
|
||||
### Key Metrics Captured
|
||||
|
||||
The benchmark suite measures critical performance indicators:
|
||||
- **Throughput**: Requests per second under sustained load
|
||||
- **Latency Distribution**: P50, P95, P99 response times
|
||||
- **Time to First Token (TTFT)**: Critical for streaming applications
|
||||
- **Error Rates**: Request failures and timeout analysis
|
||||
|
||||
This data enables data-driven architectural decisions and performance optimization efforts.
|
||||
|
||||
## Setup
|
||||
|
||||
**1. Deploy base k8s infrastructure:**
|
||||
```bash
|
||||
cd ../k8s
|
||||
./apply.sh
|
||||
```
|
||||
|
||||
**2. Deploy benchmark components:**
|
||||
```bash
|
||||
cd ../k8s-benchmark
|
||||
./apply.sh
|
||||
```
|
||||
|
||||
**3. Verify deployment:**
|
||||
```bash
|
||||
kubectl get pods
|
||||
# Should see: llama-stack-benchmark-server, vllm-server, etc.
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Benchmarks
|
||||
|
||||
**Benchmark Llama Stack (default):**
|
||||
```bash
|
||||
cd docs/source/distributions/k8s-benchmark/
|
||||
./run-benchmark.sh
|
||||
```
|
||||
|
||||
**Benchmark vLLM direct:**
|
||||
```bash
|
||||
./run-benchmark.sh --target vllm
|
||||
```
|
||||
|
||||
### Custom Configuration
|
||||
|
||||
**Extended benchmark with high concurrency:**
|
||||
```bash
|
||||
./run-benchmark.sh --target vllm --duration 120 --concurrent 20
|
||||
```
|
||||
|
||||
**Short test run:**
|
||||
```bash
|
||||
./run-benchmark.sh --target stack --duration 30 --concurrent 5
|
||||
```
|
||||
|
||||
## Command Reference
|
||||
|
||||
### run-benchmark.sh Options
|
||||
|
||||
```bash
|
||||
./run-benchmark.sh [options]
|
||||
|
||||
Options:
|
||||
-t, --target <stack|vllm> Target to benchmark (default: stack)
|
||||
-d, --duration <seconds> Duration in seconds (default: 60)
|
||||
-c, --concurrent <users> Number of concurrent users (default: 10)
|
||||
-h, --help Show help message
|
||||
|
||||
Examples:
|
||||
./run-benchmark.sh --target vllm # Benchmark vLLM direct
|
||||
./run-benchmark.sh --target stack # Benchmark Llama Stack
|
||||
./run-benchmark.sh -t vllm -d 120 -c 20 # vLLM with 120s, 20 users
|
||||
```
|
||||
|
||||
## Local Testing
|
||||
|
||||
### Running Benchmark Locally
|
||||
|
||||
For local development without Kubernetes:
|
||||
|
||||
**1. Start OpenAI mock server:**
|
||||
```bash
|
||||
uv run python openai-mock-server.py --port 8080
|
||||
```
|
||||
|
||||
**2. Run benchmark against mock server:**
|
||||
```bash
|
||||
uv run python benchmark.py \
|
||||
--base-url http://localhost:8080/v1 \
|
||||
--model mock-inference \
|
||||
--duration 30 \
|
||||
--concurrent 5
|
||||
```
|
||||
|
||||
**3. Test against local vLLM server:**
|
||||
```bash
|
||||
# If you have vLLM running locally on port 8000
|
||||
uv run python benchmark.py \
|
||||
--base-url http://localhost:8000/v1 \
|
||||
--model meta-llama/Llama-3.2-3B-Instruct \
|
||||
--duration 30 \
|
||||
--concurrent 5
|
||||
```
|
||||
|
||||
**4. Profile the running server:**
|
||||
```bash
|
||||
./profile_running_server.sh
|
||||
```
|
||||
|
||||
|
||||
|
||||
### OpenAI Mock Server
|
||||
|
||||
The `openai-mock-server.py` provides:
|
||||
- **OpenAI-compatible API** for testing without real models
|
||||
- **Configurable streaming delay** via `STREAM_DELAY_SECONDS` env var
|
||||
- **Consistent responses** for reproducible benchmarks
|
||||
- **Lightweight testing** without GPU requirements
|
||||
|
||||
**Mock server usage:**
|
||||
```bash
|
||||
uv run python openai-mock-server.py --port 8080
|
||||
```
|
||||
|
||||
The mock server is also deployed in k8s as `openai-mock-service:8080` and can be used by changing the Llama Stack configuration to use the `mock-vllm-inference` provider.
|
||||
|
||||
## Files in this Directory
|
||||
|
||||
- `benchmark.py` - Core benchmark script with async streaming support
|
||||
- `run-benchmark.sh` - Main script with target selection and configuration
|
||||
- `openai-mock-server.py` - Mock OpenAI API server for local testing
|
||||
- `README.md` - This documentation file
|
36
docs/source/distributions/k8s-benchmark/apply.sh
Executable file
36
docs/source/distributions/k8s-benchmark/apply.sh
Executable file
|
@ -0,0 +1,36 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Deploys the benchmark-specific components on top of the base k8s deployment (../k8s/apply.sh).
|
||||
|
||||
export STREAM_DELAY_SECONDS=0.005
|
||||
|
||||
export POSTGRES_USER=llamastack
|
||||
export POSTGRES_DB=llamastack
|
||||
export POSTGRES_PASSWORD=llamastack
|
||||
|
||||
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
|
||||
export MOCK_INFERENCE_MODEL=mock-inference
|
||||
|
||||
export MOCK_INFERENCE_URL=openai-mock-service:8080
|
||||
|
||||
export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
|
||||
set -euo pipefail
|
||||
set -x
|
||||
|
||||
# Deploy benchmark-specific components
|
||||
kubectl create configmap llama-stack-config --from-file=stack_run_config.yaml \
|
||||
--dry-run=client -o yaml > stack-configmap.yaml
|
||||
|
||||
kubectl apply --validate=false -f stack-configmap.yaml
|
||||
|
||||
# Deploy our custom llama stack server (overriding the base one)
|
||||
envsubst < stack-k8s.yaml.template | kubectl apply --validate=false -f -
|
268
docs/source/distributions/k8s-benchmark/benchmark.py
Normal file
268
docs/source/distributions/k8s-benchmark/benchmark.py
Normal file
|
@ -0,0 +1,268 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Simple benchmark script for Llama Stack with OpenAI API compatibility.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
import statistics
|
||||
import time
|
||||
from typing import Tuple
|
||||
import aiohttp
|
||||
|
||||
|
||||
class BenchmarkStats:
|
||||
def __init__(self):
|
||||
self.response_times = []
|
||||
self.ttft_times = []
|
||||
self.chunks_received = []
|
||||
self.errors = []
|
||||
self.success_count = 0
|
||||
self.total_requests = 0
|
||||
self.concurrent_users = 0
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def add_result(self, response_time: float, chunks: int, ttft: float = None, error: str = None):
|
||||
async with self._lock:
|
||||
self.total_requests += 1
|
||||
if error:
|
||||
self.errors.append(error)
|
||||
else:
|
||||
self.success_count += 1
|
||||
self.response_times.append(response_time)
|
||||
self.chunks_received.append(chunks)
|
||||
if ttft is not None:
|
||||
self.ttft_times.append(ttft)
|
||||
|
||||
def print_summary(self):
|
||||
if not self.response_times:
|
||||
print("No successful requests to report")
|
||||
if self.errors:
|
||||
print(f"Total errors: {len(self.errors)}")
|
||||
print("First 5 errors:")
|
||||
for error in self.errors[:5]:
|
||||
print(f" {error}")
|
||||
return
|
||||
|
||||
total_time = self.end_time - self.start_time
|
||||
success_rate = (self.success_count / self.total_requests) * 100
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"BENCHMARK RESULTS")
|
||||
print(f"{'='*60}")
|
||||
print(f"Total time: {total_time:.2f}s")
|
||||
print(f"Concurrent users: {self.concurrent_users}")
|
||||
print(f"Total requests: {self.total_requests}")
|
||||
print(f"Successful requests: {self.success_count}")
|
||||
print(f"Failed requests: {len(self.errors)}")
|
||||
print(f"Success rate: {success_rate:.1f}%")
|
||||
print(f"Requests per second: {self.success_count / total_time:.2f}")
|
||||
|
||||
print(f"\nResponse Time Statistics:")
|
||||
print(f" Mean: {statistics.mean(self.response_times):.3f}s")
|
||||
print(f" Median: {statistics.median(self.response_times):.3f}s")
|
||||
print(f" Min: {min(self.response_times):.3f}s")
|
||||
print(f" Max: {max(self.response_times):.3f}s")
|
||||
|
||||
if len(self.response_times) > 1:
|
||||
print(f" Std Dev: {statistics.stdev(self.response_times):.3f}s")
|
||||
|
||||
percentiles = [50, 90, 95, 99]
|
||||
sorted_times = sorted(self.response_times)
|
||||
print(f"\nPercentiles:")
|
||||
for p in percentiles:
|
||||
idx = int(len(sorted_times) * p / 100) - 1
|
||||
idx = max(0, min(idx, len(sorted_times) - 1))
|
||||
print(f" P{p}: {sorted_times[idx]:.3f}s")
|
||||
|
||||
if self.ttft_times:
|
||||
print(f"\nTime to First Token (TTFT) Statistics:")
|
||||
print(f" Mean: {statistics.mean(self.ttft_times):.3f}s")
|
||||
print(f" Median: {statistics.median(self.ttft_times):.3f}s")
|
||||
print(f" Min: {min(self.ttft_times):.3f}s")
|
||||
print(f" Max: {max(self.ttft_times):.3f}s")
|
||||
|
||||
if len(self.ttft_times) > 1:
|
||||
print(f" Std Dev: {statistics.stdev(self.ttft_times):.3f}s")
|
||||
|
||||
sorted_ttft = sorted(self.ttft_times)
|
||||
print(f"\nTTFT Percentiles:")
|
||||
for p in percentiles:
|
||||
idx = int(len(sorted_ttft) * p / 100) - 1
|
||||
idx = max(0, min(idx, len(sorted_ttft) - 1))
|
||||
print(f" P{p}: {sorted_ttft[idx]:.3f}s")
|
||||
|
||||
if self.chunks_received:
|
||||
print(f"\nStreaming Statistics:")
|
||||
print(f" Mean chunks per response: {statistics.mean(self.chunks_received):.1f}")
|
||||
print(f" Total chunks received: {sum(self.chunks_received)}")
|
||||
|
||||
if self.errors:
|
||||
print(f"\nErrors (showing first 5):")
|
||||
for error in self.errors[:5]:
|
||||
print(f" {error}")
|
||||
|
||||
|
||||
class LlamaStackBenchmark:
|
||||
def __init__(self, base_url: str, model_id: str):
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.model_id = model_id
|
||||
self.headers = {"Content-Type": "application/json"}
|
||||
self.test_messages = [
|
||||
[{"role": "user", "content": "Hi"}],
|
||||
[{"role": "user", "content": "What is the capital of France?"}],
|
||||
[{"role": "user", "content": "Explain quantum physics in simple terms."}],
|
||||
[{"role": "user", "content": "Write a short story about a robot learning to paint."}],
|
||||
[
|
||||
{"role": "user", "content": "What is machine learning?"},
|
||||
{"role": "assistant", "content": "Machine learning is a subset of AI..."},
|
||||
{"role": "user", "content": "Can you give me a practical example?"}
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
async def make_async_streaming_request(self) -> Tuple[float, int, float | None, str | None]:
|
||||
"""Make a single async streaming chat completion request."""
|
||||
messages = random.choice(self.test_messages)
|
||||
payload = {
|
||||
"model": self.model_id,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
"max_tokens": 100
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
chunks_received = 0
|
||||
ttft = None
|
||||
error = None
|
||||
|
||||
session = aiohttp.ClientSession()
|
||||
|
||||
try:
|
||||
async with session.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
async for line in response.content:
|
||||
if line:
|
||||
line_str = line.decode('utf-8').strip()
|
||||
if line_str.startswith('data: '):
|
||||
chunks_received += 1
|
||||
if ttft is None:
|
||||
ttft = time.time() - start_time
|
||||
if line_str == 'data: [DONE]':
|
||||
break
|
||||
|
||||
if chunks_received == 0:
|
||||
error = "No streaming chunks received"
|
||||
else:
|
||||
text = await response.text()
|
||||
error = f"HTTP {response.status}: {text[:100]}"
|
||||
|
||||
except Exception as e:
|
||||
error = f"Request error: {str(e)}"
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
response_time = time.time() - start_time
|
||||
return response_time, chunks_received, ttft, error
|
||||
|
||||
|
||||
async def run_benchmark(self, duration: int, concurrent_users: int) -> BenchmarkStats:
|
||||
"""Run benchmark using async requests for specified duration."""
|
||||
stats = BenchmarkStats()
|
||||
stats.concurrent_users = concurrent_users
|
||||
stats.start_time = time.time()
|
||||
|
||||
print(f"Starting benchmark: {duration}s duration, {concurrent_users} concurrent users")
|
||||
print(f"Target URL: {self.base_url}/chat/completions")
|
||||
print(f"Model: {self.model_id}")
|
||||
|
||||
connector = aiohttp.TCPConnector(limit=concurrent_users)
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
|
||||
async def worker(worker_id: int):
|
||||
"""Worker that sends requests sequentially until canceled."""
|
||||
request_count = 0
|
||||
while True:
|
||||
try:
|
||||
response_time, chunks, ttft, error = await self.make_async_streaming_request()
|
||||
await stats.add_result(response_time, chunks, ttft, error)
|
||||
request_count += 1
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
await stats.add_result(0, 0, None, f"Worker {worker_id} error: {str(e)}")
|
||||
|
||||
# Progress reporting task
|
||||
async def progress_reporter():
|
||||
last_report_time = time.time()
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(1) # Report every second
|
||||
if time.time() >= last_report_time + 10: # Report every 10 seconds
|
||||
elapsed = time.time() - stats.start_time
|
||||
print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s")
|
||||
last_report_time = time.time()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
# Spawn concurrent workers
|
||||
tasks = [asyncio.create_task(worker(i)) for i in range(concurrent_users)]
|
||||
progress_task = asyncio.create_task(progress_reporter())
|
||||
tasks.append(progress_task)
|
||||
|
||||
# Wait for duration then cancel all tasks
|
||||
await asyncio.sleep(duration)
|
||||
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for all tasks to complete
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
stats.end_time = time.time()
|
||||
return stats
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Llama Stack Benchmark Tool")
|
||||
parser.add_argument("--base-url", default=os.getenv("BENCHMARK_BASE_URL", "http://localhost:8000/v1/openai/v1"),
|
||||
help="Base URL for the API (default: http://localhost:8000/v1/openai/v1)")
|
||||
parser.add_argument("--model", default=os.getenv("INFERENCE_MODEL", "test-model"),
|
||||
help="Model ID to use for requests")
|
||||
parser.add_argument("--duration", type=int, default=60,
|
||||
help="Duration in seconds to run benchmark (default: 60)")
|
||||
parser.add_argument("--concurrent", type=int, default=10,
|
||||
help="Number of concurrent users (default: 10)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark = LlamaStackBenchmark(args.base_url, args.model)
|
||||
|
||||
try:
|
||||
stats = asyncio.run(benchmark.run_benchmark(args.duration, args.concurrent))
|
||||
stats.print_summary()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nBenchmark interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"Benchmark failed: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
190
docs/source/distributions/k8s-benchmark/openai-mock-server.py
Executable file
190
docs/source/distributions/k8s-benchmark/openai-mock-server.py
Executable file
|
@ -0,0 +1,190 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
OpenAI-compatible mock server that returns:
|
||||
- Hardcoded /models response for consistent validation
|
||||
- Valid OpenAI-formatted chat completion responses with dynamic content
|
||||
"""
|
||||
|
||||
from flask import Flask, request, jsonify, Response
|
||||
import time
|
||||
import random
|
||||
import uuid
|
||||
import json
|
||||
import argparse
|
||||
import os
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# Models from environment variables
|
||||
def get_models():
|
||||
models_str = os.getenv("MOCK_MODELS", "meta-llama/Llama-3.2-3B-Instruct")
|
||||
model_ids = [m.strip() for m in models_str.split(",") if m.strip()]
|
||||
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": model_id,
|
||||
"object": "model",
|
||||
"created": 1234567890,
|
||||
"owned_by": "vllm"
|
||||
}
|
||||
for model_id in model_ids
|
||||
]
|
||||
}
|
||||
|
||||
def generate_random_text(length=50):
|
||||
"""Generate random but coherent text for responses."""
|
||||
words = [
|
||||
"Hello", "there", "I'm", "an", "AI", "assistant", "ready", "to", "help", "you",
|
||||
"with", "your", "questions", "and", "tasks", "today", "Let", "me","know", "what",
|
||||
"you'd", "like", "to", "discuss", "or", "explore", "together", "I", "can", "assist",
|
||||
"with", "various", "topics", "including", "coding", "writing", "analysis", "and", "more"
|
||||
]
|
||||
return " ".join(random.choices(words, k=length))
|
||||
|
||||
@app.route('/v1/models', methods=['GET'])
|
||||
def list_models():
|
||||
models = get_models()
|
||||
print(f"[MOCK] Returning models: {[m['id'] for m in models['data']]}")
|
||||
return jsonify(models)
|
||||
|
||||
@app.route('/v1/chat/completions', methods=['POST'])
|
||||
def chat_completions():
|
||||
"""Return OpenAI-formatted chat completion responses."""
|
||||
data = request.get_json()
|
||||
default_model = get_models()['data'][0]['id']
|
||||
model = data.get('model', default_model)
|
||||
messages = data.get('messages', [])
|
||||
stream = data.get('stream', False)
|
||||
|
||||
print(f"[MOCK] Chat completion request - model: {model}, stream: {stream}")
|
||||
|
||||
if stream:
|
||||
return handle_streaming_completion(model, messages)
|
||||
else:
|
||||
return handle_non_streaming_completion(model, messages)
|
||||
|
||||
def handle_non_streaming_completion(model, messages):
|
||||
response_text = generate_random_text(random.randint(20, 80))
|
||||
|
||||
# Calculate realistic token counts
|
||||
prompt_tokens = sum(len(str(msg.get('content', '')).split()) for msg in messages)
|
||||
completion_tokens = len(response_text.split())
|
||||
|
||||
response = {
|
||||
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response_text
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens
|
||||
}
|
||||
}
|
||||
|
||||
return jsonify(response)
|
||||
|
||||
def handle_streaming_completion(model, messages):
|
||||
def generate_stream():
|
||||
# Generate response text
|
||||
full_response = generate_random_text(random.randint(30, 100))
|
||||
words = full_response.split()
|
||||
|
||||
# Send initial chunk
|
||||
initial_chunk = {
|
||||
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": ""}
|
||||
}
|
||||
]
|
||||
}
|
||||
yield f"data: {json.dumps(initial_chunk)}\n\n"
|
||||
|
||||
# Send word by word
|
||||
for i, word in enumerate(words):
|
||||
chunk = {
|
||||
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": f"{word} " if i < len(words) - 1 else word}
|
||||
}
|
||||
]
|
||||
}
|
||||
yield f"data: {json.dumps(chunk)}\n\n"
|
||||
# Configurable delay to simulate realistic streaming
|
||||
stream_delay = float(os.getenv("STREAM_DELAY_SECONDS", "0.005"))
|
||||
time.sleep(stream_delay)
|
||||
|
||||
# Send final chunk
|
||||
final_chunk = {
|
||||
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": ""},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
}
|
||||
yield f"data: {json.dumps(final_chunk)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return Response(
|
||||
generate_stream(),
|
||||
mimetype='text/event-stream',
|
||||
headers={
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive',
|
||||
'Access-Control-Allow-Origin': '*',
|
||||
}
|
||||
)
|
||||
|
||||
@app.route('/health', methods=['GET'])
|
||||
def health():
|
||||
return jsonify({"status": "healthy", "type": "openai-mock"})
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='OpenAI-compatible mock server')
|
||||
parser.add_argument('--port', type=int, default=8081,
|
||||
help='Port to run the server on (default: 8081)')
|
||||
args = parser.parse_args()
|
||||
|
||||
port = args.port
|
||||
|
||||
models = get_models()
|
||||
print("Starting OpenAI-compatible mock server...")
|
||||
print(f"- /models endpoint with: {[m['id'] for m in models['data']]}")
|
||||
print("- OpenAI-formatted chat/completion responses with dynamic content")
|
||||
print("- Streaming support with valid SSE format")
|
||||
print(f"- Listening on: http://0.0.0.0:{port}")
|
||||
app.run(host='0.0.0.0', port=port, debug=False)
|
52
docs/source/distributions/k8s-benchmark/profile_running_server.sh
Executable file
52
docs/source/distributions/k8s-benchmark/profile_running_server.sh
Executable file
|
@ -0,0 +1,52 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Script to profile an already running Llama Stack server
|
||||
# Usage: ./profile_running_server.sh [duration_seconds] [output_file]
|
||||
|
||||
DURATION=${1:-60} # Default 60 seconds
|
||||
OUTPUT_FILE=${2:-"llama_stack_profile"} # Default output file
|
||||
|
||||
echo "Looking for running Llama Stack server..."
|
||||
|
||||
# Find the server PID
|
||||
SERVER_PID=$(ps aux | grep "llama_stack.core.server.server" | grep -v grep | awk '{print $2}' | head -1)
|
||||
|
||||
|
||||
if [ -z "$SERVER_PID" ]; then
|
||||
echo "Error: No running Llama Stack server found"
|
||||
echo "Please start your server first with:"
|
||||
echo "LLAMA_STACK_LOGGING=\"all=ERROR\" MOCK_INFERENCE_URL=http://localhost:8080 SAFETY_MODEL=llama-guard3:1b uv run --with llama-stack python -m llama_stack.core.server.server docs/source/distributions/k8s-benchmark/stack_run_config.yaml"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Found Llama Stack server with PID: $SERVER_PID"
|
||||
|
||||
# Start py-spy profiling
|
||||
echo "Starting py-spy profiling for ${DURATION} seconds..."
|
||||
echo "Output will be saved to: ${OUTPUT_FILE}.svg"
|
||||
echo ""
|
||||
echo "You can now run your load test..."
|
||||
echo ""
|
||||
|
||||
# Get the full path to py-spy
|
||||
PYSPY_PATH=$(which py-spy)
|
||||
|
||||
# Check if running as root, if not, use sudo
|
||||
if [ "$EUID" -ne 0 ]; then
|
||||
echo "py-spy requires root permissions on macOS. Running with sudo..."
|
||||
sudo "$PYSPY_PATH" record -o "${OUTPUT_FILE}.svg" -d ${DURATION} -p $SERVER_PID
|
||||
else
|
||||
"$PYSPY_PATH" record -o "${OUTPUT_FILE}.svg" -d ${DURATION} -p $SERVER_PID
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Profiling completed! Results saved to: ${OUTPUT_FILE}.svg"
|
||||
echo ""
|
||||
echo "To view the flame graph:"
|
||||
echo "open ${OUTPUT_FILE}.svg"
|
148
docs/source/distributions/k8s-benchmark/run-benchmark.sh
Executable file
148
docs/source/distributions/k8s-benchmark/run-benchmark.sh
Executable file
|
@ -0,0 +1,148 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Default values
|
||||
TARGET="stack"
|
||||
DURATION=60
|
||||
CONCURRENT=10
|
||||
|
||||
# Parse command line arguments
|
||||
usage() {
|
||||
echo "Usage: $0 [options]"
|
||||
echo "Options:"
|
||||
echo " -t, --target <stack|vllm> Target to benchmark (default: stack)"
|
||||
echo " -d, --duration <seconds> Duration in seconds (default: 60)"
|
||||
echo " -c, --concurrent <users> Number of concurrent users (default: 10)"
|
||||
echo " -h, --help Show this help message"
|
||||
echo ""
|
||||
echo "Examples:"
|
||||
echo " $0 --target vllm # Benchmark vLLM direct"
|
||||
echo " $0 --target stack # Benchmark Llama Stack (default)"
|
||||
echo " $0 -t vllm -d 120 -c 20 # vLLM with 120s duration, 20 users"
|
||||
}
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-t|--target)
|
||||
TARGET="$2"
|
||||
shift 2
|
||||
;;
|
||||
-d|--duration)
|
||||
DURATION="$2"
|
||||
shift 2
|
||||
;;
|
||||
-c|--concurrent)
|
||||
CONCURRENT="$2"
|
||||
shift 2
|
||||
;;
|
||||
-h|--help)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1"
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Validate target
|
||||
if [[ "$TARGET" != "stack" && "$TARGET" != "vllm" ]]; then
|
||||
echo "Error: Target must be 'stack' or 'vllm'"
|
||||
usage
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set configuration based on target
|
||||
if [[ "$TARGET" == "vllm" ]]; then
|
||||
BASE_URL="http://vllm-server:8000/v1"
|
||||
JOB_NAME="vllm-benchmark-job"
|
||||
echo "Benchmarking vLLM direct..."
|
||||
else
|
||||
BASE_URL="http://llama-stack-benchmark-service:8323/v1/openai/v1"
|
||||
JOB_NAME="stack-benchmark-job"
|
||||
echo "Benchmarking Llama Stack..."
|
||||
fi
|
||||
|
||||
echo "Configuration:"
|
||||
echo " Target: $TARGET"
|
||||
echo " Base URL: $BASE_URL"
|
||||
echo " Duration: ${DURATION}s"
|
||||
echo " Concurrent users: $CONCURRENT"
|
||||
echo ""
|
||||
|
||||
# Create temporary job yaml
|
||||
TEMP_YAML="/tmp/benchmark-job-temp-$(date +%s).yaml"
|
||||
cat > "$TEMP_YAML" << EOF
|
||||
apiVersion: batch/v1
|
||||
kind: Job
|
||||
metadata:
|
||||
name: $JOB_NAME
|
||||
namespace: default
|
||||
spec:
|
||||
template:
|
||||
spec:
|
||||
containers:
|
||||
- name: benchmark
|
||||
image: python:3.11-slim
|
||||
command: ["/bin/bash"]
|
||||
args:
|
||||
- "-c"
|
||||
- |
|
||||
pip install aiohttp &&
|
||||
python3 /benchmark/benchmark.py \\
|
||||
--base-url $BASE_URL \\
|
||||
--model \${INFERENCE_MODEL} \\
|
||||
--duration $DURATION \\
|
||||
--concurrent $CONCURRENT
|
||||
env:
|
||||
- name: INFERENCE_MODEL
|
||||
value: "meta-llama/Llama-3.2-3B-Instruct"
|
||||
volumeMounts:
|
||||
- name: benchmark-script
|
||||
mountPath: /benchmark
|
||||
resources:
|
||||
requests:
|
||||
memory: "256Mi"
|
||||
cpu: "250m"
|
||||
limits:
|
||||
memory: "512Mi"
|
||||
cpu: "500m"
|
||||
volumes:
|
||||
- name: benchmark-script
|
||||
configMap:
|
||||
name: benchmark-script
|
||||
restartPolicy: Never
|
||||
backoffLimit: 3
|
||||
EOF
|
||||
|
||||
echo "Creating benchmark ConfigMap..."
|
||||
kubectl create configmap benchmark-script \
|
||||
--from-file=benchmark.py=benchmark.py \
|
||||
--dry-run=client -o yaml | kubectl apply -f -
|
||||
|
||||
echo "Cleaning up any existing benchmark job..."
|
||||
kubectl delete job $JOB_NAME 2>/dev/null || true
|
||||
|
||||
echo "Deploying benchmark Job..."
|
||||
kubectl apply -f "$TEMP_YAML"
|
||||
|
||||
echo "Waiting for job to start..."
|
||||
kubectl wait --for=condition=Ready pod -l job-name=$JOB_NAME --timeout=60s
|
||||
|
||||
echo "Following benchmark logs..."
|
||||
kubectl logs -f job/$JOB_NAME
|
||||
|
||||
echo "Job completed. Checking final status..."
|
||||
kubectl get job $JOB_NAME
|
||||
|
||||
# Clean up temporary file
|
||||
rm -f "$TEMP_YAML"
|
133
docs/source/distributions/k8s-benchmark/stack-configmap.yaml
Normal file
133
docs/source/distributions/k8s-benchmark/stack-configmap.yaml
Normal file
|
@ -0,0 +1,133 @@
|
|||
apiVersion: v1
|
||||
data:
|
||||
stack_run_config.yaml: |
|
||||
version: '2'
|
||||
image_name: kubernetes-benchmark-demo
|
||||
apis:
|
||||
- agents
|
||||
- inference
|
||||
- safety
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: vllm-inference
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_URL:=http://localhost:8000/v1}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: vllm-safety
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
vector_io:
|
||||
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMADB_URL:=}
|
||||
kvstore:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
responses_store:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:+}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:+}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
metadata_store:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
table_name: llamastack_kvstore
|
||||
inference_store:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
models:
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: sentence-transformers
|
||||
model_type: embedding
|
||||
- model_id: ${env.INFERENCE_MODEL}
|
||||
provider_id: vllm-inference
|
||||
model_type: llm
|
||||
- model_id: ${env.SAFETY_MODEL}
|
||||
provider_id: vllm-safety
|
||||
model_type: llm
|
||||
shields:
|
||||
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8323
|
||||
kind: ConfigMap
|
||||
metadata:
|
||||
creationTimestamp: null
|
||||
name: llama-stack-config
|
|
@ -0,0 +1,83 @@
|
|||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: llama-benchmark-pvc
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
resources:
|
||||
requests:
|
||||
storage: 1Gi
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: llama-stack-benchmark-server
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app.kubernetes.io/name: llama-stack-benchmark
|
||||
app.kubernetes.io/component: server
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app.kubernetes.io/name: llama-stack-benchmark
|
||||
app.kubernetes.io/component: server
|
||||
spec:
|
||||
containers:
|
||||
- name: llama-stack-benchmark
|
||||
image: llamastack/distribution-starter:latest
|
||||
imagePullPolicy: Always # since we have specified latest instead of a version
|
||||
env:
|
||||
- name: ENABLE_CHROMADB
|
||||
value: "true"
|
||||
- name: CHROMADB_URL
|
||||
value: http://chromadb.default.svc.cluster.local:6000
|
||||
- name: POSTGRES_HOST
|
||||
value: postgres-server.default.svc.cluster.local
|
||||
- name: POSTGRES_PORT
|
||||
value: "5432"
|
||||
- name: INFERENCE_MODEL
|
||||
value: "${INFERENCE_MODEL}"
|
||||
- name: SAFETY_MODEL
|
||||
value: "${SAFETY_MODEL}"
|
||||
- name: TAVILY_SEARCH_API_KEY
|
||||
value: "${TAVILY_SEARCH_API_KEY}"
|
||||
- name: VLLM_URL
|
||||
value: http://vllm-server.default.svc.cluster.local:8000/v1
|
||||
- name: VLLM_MAX_TOKENS
|
||||
value: "3072"
|
||||
- name: VLLM_SAFETY_URL
|
||||
value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
|
||||
- name: VLLM_TLS_VERIFY
|
||||
value: "false"
|
||||
command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
|
||||
ports:
|
||||
- containerPort: 8323
|
||||
volumeMounts:
|
||||
- name: llama-storage
|
||||
mountPath: /root/.llama
|
||||
- name: llama-config
|
||||
mountPath: /etc/config
|
||||
volumes:
|
||||
- name: llama-storage
|
||||
persistentVolumeClaim:
|
||||
claimName: llama-benchmark-pvc
|
||||
- name: llama-config
|
||||
configMap:
|
||||
name: llama-stack-config
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: llama-stack-benchmark-service
|
||||
spec:
|
||||
selector:
|
||||
app.kubernetes.io/name: llama-stack-benchmark
|
||||
app.kubernetes.io/component: server
|
||||
ports:
|
||||
- name: http
|
||||
port: 8323
|
||||
targetPort: 8323
|
||||
type: ClusterIP
|
108
docs/source/distributions/k8s-benchmark/stack_run_config.yaml
Normal file
108
docs/source/distributions/k8s-benchmark/stack_run_config.yaml
Normal file
|
@ -0,0 +1,108 @@
|
|||
version: '2'
|
||||
image_name: kubernetes-benchmark-demo
|
||||
apis:
|
||||
- agents
|
||||
- inference
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: vllm-inference
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_URL:=http://localhost:8000/v1}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
vector_io:
|
||||
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMADB_URL:=}
|
||||
kvstore:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
responses_store:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:+}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:+}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
metadata_store:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
table_name: llamastack_kvstore
|
||||
inference_store:
|
||||
type: postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
models:
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: sentence-transformers
|
||||
model_type: embedding
|
||||
- model_id: ${env.INFERENCE_MODEL}
|
||||
provider_id: vllm-inference
|
||||
model_type: llm
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8323
|
|
@ -40,19 +40,19 @@ spec:
|
|||
value: "3072"
|
||||
- name: VLLM_SAFETY_URL
|
||||
value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
|
||||
- name: VLLM_TLS_VERIFY
|
||||
value: "false"
|
||||
- name: POSTGRES_HOST
|
||||
value: postgres-server.default.svc.cluster.local
|
||||
- name: POSTGRES_PORT
|
||||
value: "5432"
|
||||
- name: VLLM_TLS_VERIFY
|
||||
value: "false"
|
||||
- name: INFERENCE_MODEL
|
||||
value: "${INFERENCE_MODEL}"
|
||||
- name: SAFETY_MODEL
|
||||
value: "${SAFETY_MODEL}"
|
||||
- name: TAVILY_SEARCH_API_KEY
|
||||
value: "${TAVILY_SEARCH_API_KEY}"
|
||||
command: ["python", "-m", "llama_stack.core.server.server", "--config", "/etc/config/stack_run_config.yaml", "--port", "8321"]
|
||||
command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8321"]
|
||||
ports:
|
||||
- containerPort: 8321
|
||||
volumeMounts:
|
||||
|
|
|
@ -2,6 +2,15 @@
|
|||
|
||||
## Overview
|
||||
|
||||
Agents API for creating and interacting with agentic systems.
|
||||
|
||||
Main functionalities provided by this API:
|
||||
- Create agents with specific instructions and ability to use tools.
|
||||
- Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn".
|
||||
- Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details).
|
||||
- Agents can be provided with various shields (see the Safety API for more details).
|
||||
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
|
||||
|
||||
This section contains documentation for all available providers for the **agents** API.
|
||||
|
||||
## Providers
|
||||
|
|
21
docs/source/providers/batches/index.md
Normal file
21
docs/source/providers/batches/index.md
Normal file
|
@ -0,0 +1,21 @@
|
|||
# Batches
|
||||
|
||||
## Overview
|
||||
|
||||
Protocol for batch processing API operations.
|
||||
|
||||
The Batches API enables efficient processing of multiple requests in a single operation,
|
||||
particularly useful for processing large datasets, batch evaluation workflows, and
|
||||
cost-effective inference at scale.
|
||||
|
||||
Note: This API is currently under active development and may undergo changes.
|
||||
|
||||
This section contains documentation for all available providers for the **batches** API.
|
||||
|
||||
## Providers
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
inline_reference
|
||||
```
|
23
docs/source/providers/batches/inline_reference.md
Normal file
23
docs/source/providers/batches/inline_reference.md
Normal file
|
@ -0,0 +1,23 @@
|
|||
# inline::reference
|
||||
|
||||
## Description
|
||||
|
||||
Reference implementation of batches API with KVStore persistence.
|
||||
|
||||
## Configuration
|
||||
|
||||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Configuration for the key-value store backend. |
|
||||
| `max_concurrent_batches` | `<class 'int'>` | No | 1 | Maximum number of concurrent batches to process simultaneously. |
|
||||
| `max_concurrent_requests_per_batch` | `<class 'int'>` | No | 10 | Maximum number of concurrent requests to process per batch. |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db
|
||||
|
||||
```
|
||||
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
## Overview
|
||||
|
||||
Llama Stack Evaluation API for running evaluations on model and agent candidates.
|
||||
|
||||
This section contains documentation for all available providers for the **eval** API.
|
||||
|
||||
## Providers
|
||||
|
|
|
@ -226,7 +226,7 @@ uv init
|
|||
name = "llama-stack-provider-ollama"
|
||||
version = "0.1.0"
|
||||
description = "Ollama provider for Llama Stack"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
|
||||
```
|
||||
|
||||
|
|
|
@ -2,6 +2,12 @@
|
|||
|
||||
## Overview
|
||||
|
||||
Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
|
||||
This section contains documentation for all available providers for the **inference** API.
|
||||
|
||||
## Providers
|
||||
|
|
|
@ -21,5 +21,7 @@ kvstore:
|
|||
|
||||
## Deprecation Notice
|
||||
|
||||
⚠️ **Warning**: Please use the `inline::faiss` provider instead.
|
||||
```{warning}
|
||||
Please use the `inline::faiss` provider instead.
|
||||
```
|
||||
|
||||
|
|
|
@ -25,5 +25,7 @@ kvstore:
|
|||
|
||||
## Deprecation Notice
|
||||
|
||||
⚠️ **Warning**: Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.
|
||||
```{warning}
|
||||
Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.
|
||||
```
|
||||
|
||||
|
|
|
@ -204,7 +204,10 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
|
|||
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend |
|
||||
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
|
||||
|
||||
> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
|
||||
```{note}
|
||||
This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
|
||||
```
|
||||
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
|
|
|
@ -128,7 +128,9 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
|
|||
|
||||
**Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
|
||||
|
||||
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
|
||||
```{tip}
|
||||
Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
|
||||
```
|
||||
|
||||
## List the downloaded models
|
||||
|
||||
|
|
|
@ -152,7 +152,9 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
|
|||
|
||||
**Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
|
||||
|
||||
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
|
||||
```{tip}
|
||||
Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
|
||||
```
|
||||
|
||||
## List the downloaded models
|
||||
|
||||
|
|
|
@ -706,6 +706,7 @@ class Agents(Protocol):
|
|||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
|
||||
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Create a new OpenAI response.
|
||||
|
@ -713,6 +714,7 @@ class Agents(Protocol):
|
|||
:param input: Input message(s) to create the response.
|
||||
:param model: The underlying LLM used for completions.
|
||||
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
|
||||
:param include: (Optional) Additional fields to include in the response.
|
||||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -170,6 +170,23 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
|||
type: Literal["web_search_call"] = "web_search_call"
|
||||
|
||||
|
||||
class OpenAIResponseOutputMessageFileSearchToolCallResults(BaseModel):
|
||||
"""Search results returned by the file search operation.
|
||||
|
||||
:param attributes: (Optional) Key-value attributes associated with the file
|
||||
:param file_id: Unique identifier of the file containing the result
|
||||
:param filename: Name of the file containing the result
|
||||
:param score: Relevance score for this search result (between 0 and 1)
|
||||
:param text: Text content of the search result
|
||||
"""
|
||||
|
||||
attributes: dict[str, Any]
|
||||
file_id: str
|
||||
filename: str
|
||||
score: float
|
||||
text: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
|
||||
"""File search tool call output message for OpenAI responses.
|
||||
|
@ -185,7 +202,7 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
|
|||
queries: list[str]
|
||||
status: str
|
||||
type: Literal["file_search_call"] = "file_search_call"
|
||||
results: list[dict[str, Any]] | None = None
|
||||
results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -606,6 +623,62 @@ class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel):
|
|||
type: Literal["response.mcp_call.completed"] = "response.mcp_call.completed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseContentPartOutputText(BaseModel):
|
||||
type: Literal["output_text"] = "output_text"
|
||||
text: str
|
||||
# TODO: add annotations, logprobs, etc.
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseContentPartRefusal(BaseModel):
|
||||
type: Literal["refusal"] = "refusal"
|
||||
refusal: str
|
||||
|
||||
|
||||
OpenAIResponseContentPart = Annotated[
|
||||
OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel):
|
||||
"""Streaming event for when a new content part is added to a response item.
|
||||
|
||||
:param response_id: Unique identifier of the response containing this content
|
||||
:param item_id: Unique identifier of the output item containing this content part
|
||||
:param part: The content part that was added
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.content_part.added"
|
||||
"""
|
||||
|
||||
response_id: str
|
||||
item_id: str
|
||||
part: OpenAIResponseContentPart
|
||||
sequence_number: int
|
||||
type: Literal["response.content_part.added"] = "response.content_part.added"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseContentPartDone(BaseModel):
|
||||
"""Streaming event for when a content part is completed.
|
||||
|
||||
:param response_id: Unique identifier of the response containing this content
|
||||
:param item_id: Unique identifier of the output item containing this content part
|
||||
:param part: The completed content part
|
||||
:param sequence_number: Sequential number for ordering streaming events
|
||||
:param type: Event type identifier, always "response.content_part.done"
|
||||
"""
|
||||
|
||||
response_id: str
|
||||
item_id: str
|
||||
part: OpenAIResponseContentPart
|
||||
sequence_number: int
|
||||
type: Literal["response.content_part.done"] = "response.content_part.done"
|
||||
|
||||
|
||||
OpenAIResponseObjectStream = Annotated[
|
||||
OpenAIResponseObjectStreamResponseCreated
|
||||
| OpenAIResponseObjectStreamResponseOutputItemAdded
|
||||
|
@ -625,6 +698,8 @@ OpenAIResponseObjectStream = Annotated[
|
|||
| OpenAIResponseObjectStreamResponseMcpCallInProgress
|
||||
| OpenAIResponseObjectStreamResponseMcpCallFailed
|
||||
| OpenAIResponseObjectStreamResponseMcpCallCompleted
|
||||
| OpenAIResponseObjectStreamResponseContentPartAdded
|
||||
| OpenAIResponseObjectStreamResponseContentPartDone
|
||||
| OpenAIResponseObjectStreamResponseCompleted,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
9
llama_stack/apis/batches/__init__.py
Normal file
9
llama_stack/apis/batches/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .batches import Batches, BatchObject, ListBatchesResponse
|
||||
|
||||
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
|
89
llama_stack/apis/batches/batches.py
Normal file
89
llama_stack/apis/batches/batches.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
try:
|
||||
from openai.types import Batch as BatchObject
|
||||
except ImportError as e:
|
||||
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListBatchesResponse(BaseModel):
|
||||
"""Response containing a list of batch objects."""
|
||||
|
||||
object: Literal["list"] = "list"
|
||||
data: list[BatchObject] = Field(..., description="List of batch objects")
|
||||
first_id: str | None = Field(default=None, description="ID of the first batch in the list")
|
||||
last_id: str | None = Field(default=None, description="ID of the last batch in the list")
|
||||
has_more: bool = Field(default=False, description="Whether there are more batches available")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Batches(Protocol):
|
||||
"""Protocol for batch processing API operations.
|
||||
|
||||
The Batches API enables efficient processing of multiple requests in a single operation,
|
||||
particularly useful for processing large datasets, batch evaluation workflows, and
|
||||
cost-effective inference at scale.
|
||||
|
||||
Note: This API is currently under active development and may undergo changes.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="POST")
|
||||
async def create_batch(
|
||||
self,
|
||||
input_file_id: str,
|
||||
endpoint: str,
|
||||
completion_window: Literal["24h"],
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> BatchObject:
|
||||
"""Create a new batch for processing multiple API requests.
|
||||
|
||||
:param input_file_id: The ID of an uploaded file containing requests for the batch.
|
||||
:param endpoint: The endpoint to be used for all requests in the batch.
|
||||
:param completion_window: The time window within which the batch should be processed.
|
||||
:param metadata: Optional metadata for the batch.
|
||||
:returns: The created batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET")
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Retrieve information about a specific batch.
|
||||
|
||||
:param batch_id: The ID of the batch to retrieve.
|
||||
:returns: The batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST")
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Cancel a batch that is in progress.
|
||||
|
||||
:param batch_id: The ID of the batch to cancel.
|
||||
:returns: The updated batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="GET")
|
||||
async def list_batches(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ListBatchesResponse:
|
||||
"""List all batches for the current user.
|
||||
|
||||
:param after: A cursor for pagination; returns batches after this batch ID.
|
||||
:param limit: Number of batches to return (default 20, max 100).
|
||||
:returns: A list of batch objects.
|
||||
"""
|
||||
...
|
|
@ -62,3 +62,20 @@ class SessionNotFoundError(ValueError):
|
|||
def __init__(self, session_name: str) -> None:
|
||||
message = f"Session '{session_name}' not found or access denied."
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelTypeError(TypeError):
|
||||
"""raised when a model is present but not the correct type"""
|
||||
|
||||
def __init__(self, model_name: str, model_type: str, expected_model_type: str) -> None:
|
||||
message = (
|
||||
f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'"
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ConflictError(ValueError):
|
||||
"""raised when an operation cannot be performed due to a conflict with the current state"""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
|
|
|
@ -86,6 +86,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
:cvar inference: Text generation, chat completions, and embeddings
|
||||
:cvar safety: Content moderation and safety shields
|
||||
:cvar agents: Agent orchestration and execution
|
||||
:cvar batches: Batch processing for asynchronous API requests
|
||||
:cvar vector_io: Vector database operations and queries
|
||||
:cvar datasetio: Dataset input/output operations
|
||||
:cvar scoring: Model output evaluation and scoring
|
||||
|
@ -108,6 +109,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
inference = "inference"
|
||||
safety = "safety"
|
||||
agents = "agents"
|
||||
batches = "batches"
|
||||
vector_io = "vector_io"
|
||||
datasetio = "datasetio"
|
||||
scoring = "scoring"
|
||||
|
|
|
@ -22,6 +22,7 @@ class OpenAIFilePurpose(StrEnum):
|
|||
"""
|
||||
|
||||
ASSISTANTS = "assistants"
|
||||
BATCH = "batch"
|
||||
# TODO: Add other purposes as needed
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum, StrEnum
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -15,27 +15,6 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
|||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
# OpenAI Categories to return in the response
|
||||
class OpenAICategories(StrEnum):
|
||||
"""
|
||||
Required set of categories in moderations api response
|
||||
"""
|
||||
|
||||
VIOLENCE = "violence"
|
||||
VIOLENCE_GRAPHIC = "violence/graphic"
|
||||
HARRASMENT = "harassment"
|
||||
HARRASMENT_THREATENING = "harassment/threatening"
|
||||
HATE = "hate"
|
||||
HATE_THREATENING = "hate/threatening"
|
||||
ILLICIT = "illicit"
|
||||
ILLICIT_VIOLENT = "illicit/violent"
|
||||
SEXUAL = "sexual"
|
||||
SEXUAL_MINORS = "sexual/minors"
|
||||
SELF_HARM = "self-harm"
|
||||
SELF_HARM_INTENT = "self-harm/intent"
|
||||
SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModerationObjectResults(BaseModel):
|
||||
"""A moderation object.
|
||||
|
@ -43,20 +22,6 @@ class ModerationObjectResults(BaseModel):
|
|||
:param categories: A list of the categories, and whether they are flagged or not.
|
||||
:param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
|
||||
:param category_scores: A list of the categories along with their scores as predicted by model.
|
||||
Required set of categories that need to be in response
|
||||
- violence
|
||||
- violence/graphic
|
||||
- harassment
|
||||
- harassment/threatening
|
||||
- hate
|
||||
- hate/threatening
|
||||
- illicit
|
||||
- illicit/violent
|
||||
- sexual
|
||||
- sexual/minors
|
||||
- self-harm
|
||||
- self-harm/intent
|
||||
- self-harm/instructions
|
||||
"""
|
||||
|
||||
flagged: bool
|
||||
|
|
|
@ -91,7 +91,7 @@ def get_provider_dependencies(
|
|||
|
||||
|
||||
def print_pip_install_help(config: BuildConfig):
|
||||
normal_deps, special_deps = get_provider_dependencies(config)
|
||||
normal_deps, special_deps, _ = get_provider_dependencies(config)
|
||||
|
||||
cprint(
|
||||
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
|
||||
|
|
|
@ -380,8 +380,17 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
json_content = json.dumps(convert_pydantic_to_json_value(result))
|
||||
|
||||
filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}
|
||||
|
||||
status_code = httpx.codes.OK
|
||||
|
||||
if options.method.upper() == "DELETE" and result is None:
|
||||
status_code = httpx.codes.NO_CONTENT
|
||||
|
||||
if status_code == httpx.codes.NO_CONTENT:
|
||||
json_content = ""
|
||||
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
status_code=status_code,
|
||||
content=json_content.encode("utf-8"),
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
|
|
|
@ -8,6 +8,7 @@ import inspect
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batches import Batches
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
@ -75,6 +76,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
|||
Api.agents: Agents,
|
||||
Api.inference: Inference,
|
||||
Api.inspect: Inspect,
|
||||
Api.batches: Batches,
|
||||
Api.vector_io: VectorIO,
|
||||
Api.vector_dbs: VectorDBs,
|
||||
Api.models: Models,
|
||||
|
|
|
@ -18,7 +18,7 @@ from llama_stack.apis.common.content_types import (
|
|||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||
from llama_stack.apis.inference import (
|
||||
BatchChatCompletionResponse,
|
||||
BatchCompletionResponse,
|
||||
|
@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
|
|||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class InferenceRouter(Inference):
|
||||
|
@ -177,6 +177,15 @@ class InferenceRouter(Inference):
|
|||
encoded = self.formatter.encode_content(messages)
|
||||
return len(encoded.tokens) if encoded and encoded.tokens else 0
|
||||
|
||||
async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
|
||||
"""takes a model id and gets model after ensuring that it is accessible and of the correct type"""
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(model_id)
|
||||
if model.model_type != expected_model_type:
|
||||
raise ModelTypeError(model_id, model.model_type, expected_model_type)
|
||||
return model
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -195,11 +204,7 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(model_id)
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
model = await self._get_model(model_id, ModelType.llm)
|
||||
if tool_config:
|
||||
if tool_choice and tool_choice != tool_config.tool_choice:
|
||||
raise ValueError("tool_choice and tool_config.tool_choice must match")
|
||||
|
@ -301,11 +306,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
|
||||
)
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(model_id)
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
model = await self._get_model(model_id, ModelType.llm)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
|
@ -355,11 +356,7 @@ class InferenceRouter(Inference):
|
|||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
logger.debug(f"InferenceRouter.embeddings: {model_id}")
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(model_id)
|
||||
if model.model_type == ModelType.llm:
|
||||
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
|
||||
await self._get_model(model_id, ModelType.embedding)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.embeddings(
|
||||
model_id=model_id,
|
||||
|
@ -395,12 +392,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
|
||||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
if model_obj is None:
|
||||
raise ModelNotFoundError(model)
|
||||
if model_obj.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
|
||||
|
||||
model_obj = await self._get_model(model, ModelType.llm)
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
prompt=prompt,
|
||||
|
@ -476,11 +468,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
|
||||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
if model_obj is None:
|
||||
raise ModelNotFoundError(model)
|
||||
if model_obj.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
|
||||
model_obj = await self._get_model(model, ModelType.llm)
|
||||
|
||||
# Use the OpenAI client for a bit of extra input validation without
|
||||
# exposing the OpenAI client itself as part of our API surface
|
||||
|
@ -567,12 +555,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
|
||||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
if model_obj is None:
|
||||
raise ModelNotFoundError(model)
|
||||
if model_obj.model_type != ModelType.embedding:
|
||||
raise ValueError(f"Model '{model}' is not an embedding model")
|
||||
|
||||
model_obj = await self._get_model(model, ModelType.embedding)
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
input=input,
|
||||
|
@ -871,4 +854,5 @@ class InferenceRouter(Inference):
|
|||
model=model.identifier,
|
||||
object="chat.completion",
|
||||
)
|
||||
logger.debug(f"InferenceRouter.completion_response: {final_response}")
|
||||
await self.store.store_chat_completion(final_response, messages)
|
||||
|
|
|
@ -10,7 +10,7 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
)
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories
|
||||
from llama_stack.apis.safety.safety import ModerationObject
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
@ -82,20 +82,5 @@ class SafetyRouter(Safety):
|
|||
input=input,
|
||||
model=model,
|
||||
)
|
||||
self._validate_required_categories_exist(response)
|
||||
|
||||
return response
|
||||
|
||||
def _validate_required_categories_exist(self, response: ModerationObject) -> None:
|
||||
"""Validate the ProviderImpl response contains the required Open AI moderations categories."""
|
||||
required_categories = list(map(str, OpenAICategories))
|
||||
|
||||
categories = response.results[0].categories
|
||||
category_applied_input_types = response.results[0].category_applied_input_types
|
||||
category_scores = response.results[0].category_scores
|
||||
|
||||
for i in [categories, category_applied_input_types, category_scores]:
|
||||
if not set(required_categories).issubset(set(i.keys())):
|
||||
raise ValueError(
|
||||
f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}"
|
||||
)
|
||||
|
|
|
@ -63,6 +63,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
|
||||
async def get_provider_impl(self, model_id: str) -> Any:
|
||||
model = await lookup_model(self, model_id)
|
||||
if model.provider_id not in self.impls_by_provider_id:
|
||||
raise ValueError(f"Provider {model.provider_id} not found in the routing table")
|
||||
return self.impls_by_provider_id[model.provider_id]
|
||||
|
||||
async def register_model(
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, VectorStoreNotFoundError
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
|
@ -66,7 +66,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
if model is None:
|
||||
raise ModelNotFoundError(embedding_model)
|
||||
if model.model_type != ModelType.embedding:
|
||||
raise ValueError(f"Model {embedding_model} is not an embedding model")
|
||||
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
||||
vector_db_data = {
|
||||
|
|
|
@ -21,16 +21,18 @@ from importlib.metadata import version as parse_version
|
|||
from pathlib import Path
|
||||
from typing import Annotated, Any, get_origin
|
||||
|
||||
import httpx
|
||||
import rich.pretty
|
||||
import yaml
|
||||
from aiohttp import hdrs
|
||||
from fastapi import Body, FastAPI, HTTPException, Request
|
||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi import Path as FastapiPath
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from openai import BadRequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError
|
||||
|
@ -115,7 +117,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
|||
|
||||
if isinstance(exc, RequestValidationError):
|
||||
return HTTPException(
|
||||
status_code=400,
|
||||
status_code=httpx.codes.BAD_REQUEST,
|
||||
detail={
|
||||
"errors": [
|
||||
{
|
||||
|
@ -127,21 +129,25 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
|||
]
|
||||
},
|
||||
)
|
||||
elif isinstance(exc, ConflictError):
|
||||
return HTTPException(status_code=409, detail=str(exc))
|
||||
elif isinstance(exc, ResourceNotFoundError):
|
||||
return HTTPException(status_code=404, detail=str(exc))
|
||||
elif isinstance(exc, ValueError):
|
||||
return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
|
||||
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
|
||||
elif isinstance(exc, BadRequestError):
|
||||
return HTTPException(status_code=400, detail=str(exc))
|
||||
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
|
||||
elif isinstance(exc, PermissionError | AccessDeniedError):
|
||||
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
|
||||
return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
|
||||
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
|
||||
return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
|
||||
return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
|
||||
elif isinstance(exc, NotImplementedError):
|
||||
return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
|
||||
return HTTPException(status_code=httpx.codes.NOT_IMPLEMENTED, detail=f"Not implemented: {str(exc)}")
|
||||
elif isinstance(exc, AuthenticationRequiredError):
|
||||
return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
|
||||
return HTTPException(status_code=httpx.codes.UNAUTHORIZED, detail=f"Authentication required: {str(exc)}")
|
||||
else:
|
||||
return HTTPException(
|
||||
status_code=500,
|
||||
status_code=httpx.codes.INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error: An unexpected error occurred.",
|
||||
)
|
||||
|
||||
|
@ -180,7 +186,6 @@ async def sse_generator(event_gen_coroutine):
|
|||
event_gen = await event_gen_coroutine
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Generator cancelled")
|
||||
if event_gen:
|
||||
|
@ -236,6 +241,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
result = await maybe_await(value)
|
||||
if isinstance(result, PaginatedResponse) and result.url is None:
|
||||
result.url = route
|
||||
|
||||
if method.upper() == "DELETE" and result is None:
|
||||
return Response(status_code=httpx.codes.NO_CONTENT)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
|
@ -352,7 +361,7 @@ class ClientVersionMiddleware:
|
|||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 426,
|
||||
"status": httpx.codes.UPGRADE_REQUIRED,
|
||||
"headers": [[b"content-type", b"application/json"]],
|
||||
}
|
||||
)
|
||||
|
|
|
@ -48,6 +48,8 @@ distribution_spec:
|
|||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_type: inline::reference
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
|
|
|
@ -2,6 +2,7 @@ version: 2
|
|||
image_name: ci-tests
|
||||
apis:
|
||||
- agents
|
||||
- batches
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
|
@ -204,6 +205,13 @@ providers:
|
|||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_id: reference
|
||||
provider_type: inline::reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/batches.db
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/registry.db
|
||||
|
|
|
@ -16,6 +16,7 @@ from llama_stack.distributions.template import DistributionTemplate, RunConfigSe
|
|||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
|
@ -71,9 +72,10 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
chromadb_provider = Provider(
|
||||
provider_id="chromadb",
|
||||
provider_type="remote::chromadb",
|
||||
config={
|
||||
"url": "${env.CHROMA_URL}",
|
||||
},
|
||||
config=ChromaVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}/",
|
||||
url="${env.CHROMADB_URL:=}",
|
||||
),
|
||||
)
|
||||
|
||||
inference_model = ModelInput(
|
||||
|
|
|
@ -26,7 +26,10 @@ providers:
|
|||
- provider_id: chromadb
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMA_URL}
|
||||
url: ${env.CHROMADB_URL:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
|
|
|
@ -22,7 +22,10 @@ providers:
|
|||
- provider_id: chromadb
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMA_URL}
|
||||
url: ${env.CHROMADB_URL:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
|
|
|
@ -48,6 +48,8 @@ distribution_spec:
|
|||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_type: inline::reference
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
|
|
|
@ -2,6 +2,7 @@ version: 2
|
|||
image_name: starter
|
||||
apis:
|
||||
- agents
|
||||
- batches
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
|
@ -204,6 +205,13 @@ providers:
|
|||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_id: reference
|
||||
provider_type: inline::reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/batches.db
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/registry.db
|
||||
|
|
|
@ -139,6 +139,9 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
"batches": [
|
||||
BuildProvider(provider_type="inline::reference"),
|
||||
],
|
||||
}
|
||||
files_provider = Provider(
|
||||
provider_id="meta-reference-files",
|
||||
|
|
|
@ -32,6 +32,7 @@ CATEGORIES = [
|
|||
"tools",
|
||||
"client",
|
||||
"telemetry",
|
||||
"openai_responses",
|
||||
]
|
||||
|
||||
# Initialize category levels with default level
|
||||
|
|
|
@ -236,6 +236,7 @@ class ChatFormat:
|
|||
arguments_json=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
||||
return RawMessage(
|
||||
role="assistant",
|
||||
|
|
|
@ -48,8 +48,8 @@ from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
|||
|
||||
from .agent_instance import ChatAgent
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
from .openai_responses import OpenAIResponsesImpl
|
||||
from .persistence import AgentInfo
|
||||
from .responses.openai_responses import OpenAIResponsesImpl
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
@ -327,10 +327,21 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.create_openai_response(
|
||||
input, model, instructions, previous_response_id, store, stream, temperature, text, tools, max_infer_iters
|
||||
input,
|
||||
model,
|
||||
instructions,
|
||||
previous_response_id,
|
||||
store,
|
||||
stream,
|
||||
temperature,
|
||||
text,
|
||||
tools,
|
||||
include,
|
||||
max_infer_iters,
|
||||
)
|
||||
|
||||
async def list_openai_responses(
|
||||
|
|
|
@ -1,880 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
AllowedToolsFilter,
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIDeleteResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContent,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
OpenAIResponseObjectStreamResponseCreated,
|
||||
OpenAIResponseObjectStreamResponseOutputTextDelta,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
WebSearchToolTypes,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import TextContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAIJSONSchema,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAIResponseFormatText,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
|
||||
OPENAI_RESPONSES_PREFIX = "openai_responses:"
|
||||
|
||||
|
||||
async def _convert_response_content_to_chat_content(
|
||||
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent],
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
"""
|
||||
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
||||
|
||||
The content schemas of each API look similar, but are not exactly the same.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
converted_parts = []
|
||||
for content_part in content:
|
||||
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
|
||||
if content_part.image_url:
|
||||
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
|
||||
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||
elif isinstance(content_part, str):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
|
||||
)
|
||||
return converted_parts
|
||||
|
||||
|
||||
async def _convert_response_input_to_chat_messages(
|
||||
input: str | list[OpenAIResponseInput],
|
||||
) -> list[OpenAIMessageParam]:
|
||||
"""
|
||||
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
|
||||
"""
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if isinstance(input, list):
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.call_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id=input_item.call_id,
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=input_item.name,
|
||||
arguments=input_item.arguments,
|
||||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
else:
|
||||
content = await _convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await _get_message_type_by_role(input_item.role)
|
||||
if message_type is None:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
messages.append(message_type(content=content))
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
||||
|
||||
async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
|
||||
"""
|
||||
Convert an OpenAI Chat Completion choice into an OpenAI Response output message.
|
||||
"""
|
||||
output_content = ""
|
||||
if isinstance(choice.message.content, str):
|
||||
output_content = choice.message.content
|
||||
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
|
||||
output_content = choice.message.content.text
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
|
||||
)
|
||||
|
||||
return OpenAIResponseMessage(
|
||||
id=f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
||||
async def _convert_response_text_to_chat_response_format(text: OpenAIResponseText) -> OpenAIResponseFormatParam:
|
||||
"""
|
||||
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
|
||||
"""
|
||||
if not text.format or text.format["type"] == "text":
|
||||
return OpenAIResponseFormatText(type="text")
|
||||
if text.format["type"] == "json_object":
|
||||
return OpenAIResponseFormatJSONObject()
|
||||
if text.format["type"] == "json_schema":
|
||||
return OpenAIResponseFormatJSONSchema(
|
||||
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
|
||||
)
|
||||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
||||
async def _get_message_type_by_role(role: str):
|
||||
role_to_type = {
|
||||
"user": OpenAIUserMessageParam,
|
||||
"system": OpenAISystemMessageParam,
|
||||
"assistant": OpenAIAssistantMessageParam,
|
||||
"developer": OpenAIDeveloperMessageParam,
|
||||
}
|
||||
return role_to_type.get(role)
|
||||
|
||||
|
||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||
input_items: ListOpenAIResponseInputItem
|
||||
response: OpenAIResponseObject
|
||||
|
||||
|
||||
class ChatCompletionContext(BaseModel):
|
||||
model: str
|
||||
messages: list[OpenAIMessageParam]
|
||||
response_tools: list[OpenAIResponseInputTool] | None = None
|
||||
chat_tools: list[ChatCompletionToolParam] | None = None
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
|
||||
temperature: float | None
|
||||
response_format: OpenAIResponseFormatParam
|
||||
|
||||
|
||||
class OpenAIResponsesImpl:
|
||||
def __init__(
|
||||
self,
|
||||
inference_api: Inference,
|
||||
tool_groups_api: ToolGroups,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
responses_store: ResponsesStore,
|
||||
vector_io_api: VectorIO, # VectorIO
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.responses_store = responses_store
|
||||
self.vector_io_api = vector_io_api
|
||||
|
||||
async def _prepend_previous_response(
|
||||
self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
|
||||
):
|
||||
if previous_response_id:
|
||||
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
|
||||
|
||||
# previous response input items
|
||||
new_input_items = previous_response_with_input.input
|
||||
|
||||
# previous response output items
|
||||
new_input_items.extend(previous_response_with_input.output)
|
||||
|
||||
# new input items from the current request
|
||||
if isinstance(input, str):
|
||||
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
|
||||
else:
|
||||
new_input_items.extend(input)
|
||||
|
||||
input = new_input_items
|
||||
|
||||
return input
|
||||
|
||||
async def _prepend_instructions(self, messages, instructions):
|
||||
if instructions:
|
||||
messages.insert(0, OpenAISystemMessageParam(content=instructions))
|
||||
|
||||
async def get_openai_response(
|
||||
self,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
response_with_input = await self.responses_store.get_response_object(response_id)
|
||||
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
return await self.responses_store.list_responses(after, limit, model, order)
|
||||
|
||||
async def list_openai_response_input_items(
|
||||
self,
|
||||
response_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
include: list[str] | None = None,
|
||||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
"""List input items for a given OpenAI response.
|
||||
|
||||
:param response_id: The ID of the response to retrieve input items for.
|
||||
:param after: An item ID to list items after, used for pagination.
|
||||
:param before: An item ID to list items before, used for pagination.
|
||||
:param include: Additional fields to include in the response.
|
||||
:param limit: A limit on the number of objects to be returned.
|
||||
:param order: The order to return the input items in.
|
||||
:returns: An ListOpenAIResponseInputItem.
|
||||
"""
|
||||
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
|
||||
|
||||
async def _store_response(
|
||||
self,
|
||||
response: OpenAIResponseObject,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
) -> None:
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
if isinstance(input, str):
|
||||
# synthesize a message from the input string
|
||||
input_content = OpenAIResponseInputMessageContentText(text=input)
|
||||
input_content_item = OpenAIResponseMessage(
|
||||
role="user",
|
||||
content=[input_content],
|
||||
id=new_input_id,
|
||||
)
|
||||
input_items_data = [input_content_item]
|
||||
else:
|
||||
# we already have a list of messages
|
||||
input_items_data = []
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseMessage):
|
||||
# These may or may not already have an id, so dump to dict, check for id, and add if missing
|
||||
input_item_dict = input_item.model_dump()
|
||||
if "id" not in input_item_dict:
|
||||
input_item_dict["id"] = new_input_id
|
||||
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
|
||||
else:
|
||||
input_items_data.append(input_item)
|
||||
|
||||
await self.responses_store.store_response_object(
|
||||
response_object=response,
|
||||
input=input_items_data,
|
||||
)
|
||||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
):
|
||||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
||||
stream_gen = self._create_streaming_response(
|
||||
input=input,
|
||||
model=model,
|
||||
instructions=instructions,
|
||||
previous_response_id=previous_response_id,
|
||||
store=store,
|
||||
temperature=temperature,
|
||||
text=text,
|
||||
tools=tools,
|
||||
max_infer_iters=max_infer_iters,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return stream_gen
|
||||
else:
|
||||
response = None
|
||||
async for stream_chunk in stream_gen:
|
||||
if stream_chunk.type == "response.completed":
|
||||
if response is not None:
|
||||
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
|
||||
response = stream_chunk.response
|
||||
# don't leave the generator half complete!
|
||||
|
||||
if response is None:
|
||||
raise ValueError("The response stream never completed")
|
||||
return response
|
||||
|
||||
async def _create_streaming_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
|
||||
# Input preprocessing
|
||||
input = await self._prepend_previous_response(input, previous_response_id)
|
||||
messages = await _convert_response_input_to_chat_messages(input)
|
||||
await self._prepend_instructions(messages, instructions)
|
||||
|
||||
# Structured outputs
|
||||
response_format = await _convert_response_text_to_chat_response_format(text)
|
||||
|
||||
# Tool setup, TODO: refactor this slightly since this can also yield events
|
||||
chat_tools, mcp_tool_to_server, mcp_list_message = (
|
||||
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
|
||||
)
|
||||
if mcp_list_message:
|
||||
output_messages.append(mcp_list_message)
|
||||
|
||||
ctx = ChatCompletionContext(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_tools=tools,
|
||||
chat_tools=chat_tools,
|
||||
mcp_tool_to_server=mcp_tool_to_server,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
# Create initial response and emit response.created immediately
|
||||
response_id = f"resp-{uuid.uuid4()}"
|
||||
created_at = int(time.time())
|
||||
|
||||
initial_response = OpenAIResponseObject(
|
||||
created_at=created_at,
|
||||
id=response_id,
|
||||
model=model,
|
||||
object="response",
|
||||
status="in_progress",
|
||||
output=output_messages.copy(),
|
||||
text=text,
|
||||
)
|
||||
|
||||
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
|
||||
|
||||
n_iter = 0
|
||||
messages = ctx.messages.copy()
|
||||
|
||||
while True:
|
||||
completion_result = await self.inference_api.openai_chat_completion(
|
||||
model=ctx.model,
|
||||
messages=messages,
|
||||
tools=ctx.chat_tools,
|
||||
stream=True,
|
||||
temperature=ctx.temperature,
|
||||
response_format=ctx.response_format,
|
||||
)
|
||||
|
||||
# Process streaming chunks and build complete response
|
||||
chat_response_id = ""
|
||||
chat_response_content = []
|
||||
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
||||
chunk_created = 0
|
||||
chunk_model = ""
|
||||
chunk_finish_reason = ""
|
||||
sequence_number = 0
|
||||
|
||||
# Create a placeholder message item for delta events
|
||||
message_item_id = f"msg_{uuid.uuid4()}"
|
||||
|
||||
async for chunk in completion_result:
|
||||
chat_response_id = chunk.id
|
||||
chunk_created = chunk.created
|
||||
chunk_model = chunk.model
|
||||
for chunk_choice in chunk.choices:
|
||||
# Emit incremental text content as delta events
|
||||
if chunk_choice.delta.content:
|
||||
sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||
content_index=0,
|
||||
delta=chunk_choice.delta.content,
|
||||
item_id=message_item_id,
|
||||
output_index=0,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
# Collect content for final response
|
||||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
if chunk_choice.finish_reason:
|
||||
chunk_finish_reason = chunk_choice.finish_reason
|
||||
|
||||
# Aggregate tool call arguments across chunks
|
||||
if chunk_choice.delta.tool_calls:
|
||||
for tool_call in chunk_choice.delta.tool_calls:
|
||||
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||
if response_tool_call:
|
||||
# Don't attempt to concatenate arguments if we don't have any new argumentsAdd commentMore actions
|
||||
if tool_call.function.arguments:
|
||||
# Guard against an initial None argument before we concatenate
|
||||
response_tool_call.function.arguments = (
|
||||
response_tool_call.function.arguments or ""
|
||||
) + tool_call.function.arguments
|
||||
else:
|
||||
tool_call_dict: dict[str, Any] = tool_call.model_dump()
|
||||
tool_call_dict.pop("type", None)
|
||||
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
|
||||
chat_response_tool_calls[tool_call.index] = response_tool_call
|
||||
|
||||
# Convert collected chunks to complete response
|
||||
if chat_response_tool_calls:
|
||||
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
|
||||
else:
|
||||
tool_calls = None
|
||||
assistant_message = OpenAIAssistantMessageParam(
|
||||
content="".join(chat_response_content),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
current_response = OpenAIChatCompletion(
|
||||
id=chat_response_id,
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=assistant_message,
|
||||
finish_reason=chunk_finish_reason,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=chunk_created,
|
||||
model=chunk_model,
|
||||
)
|
||||
|
||||
function_tool_calls = []
|
||||
non_function_tool_calls = []
|
||||
|
||||
next_turn_messages = messages.copy()
|
||||
for choice in current_response.choices:
|
||||
next_turn_messages.append(choice.message)
|
||||
|
||||
if choice.message.tool_calls and tools:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if _is_function_tool_call(tool_call, tools):
|
||||
function_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
else:
|
||||
output_messages.append(await _convert_chat_choice_to_response_message(choice))
|
||||
|
||||
# execute non-function tool calls
|
||||
for tool_call in non_function_tool_calls:
|
||||
tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
|
||||
if tool_call_log:
|
||||
output_messages.append(tool_call_log)
|
||||
if tool_response_message:
|
||||
next_turn_messages.append(tool_response_message)
|
||||
|
||||
for tool_call in function_tool_calls:
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
arguments=tool_call.function.arguments or "",
|
||||
call_id=tool_call.id,
|
||||
name=tool_call.function.name or "",
|
||||
id=f"fc_{uuid.uuid4()}",
|
||||
status="completed",
|
||||
)
|
||||
)
|
||||
|
||||
if not function_tool_calls and not non_function_tool_calls:
|
||||
break
|
||||
|
||||
if function_tool_calls:
|
||||
logger.info("Exiting inference loop since there is a function (client-side) tool call")
|
||||
break
|
||||
|
||||
n_iter += 1
|
||||
if n_iter >= max_infer_iters:
|
||||
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
|
||||
break
|
||||
|
||||
messages = next_turn_messages
|
||||
|
||||
# Create final response
|
||||
final_response = OpenAIResponseObject(
|
||||
created_at=created_at,
|
||||
id=response_id,
|
||||
model=model,
|
||||
object="response",
|
||||
status="completed",
|
||||
text=text,
|
||||
output=output_messages,
|
||||
)
|
||||
|
||||
# Emit response.completed
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||
|
||||
if store:
|
||||
await self._store_response(
|
||||
response=final_response,
|
||||
input=input,
|
||||
)
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
return await self.responses_store.delete_response_object(response_id)
|
||||
|
||||
async def _convert_response_tools_to_chat_tools(
|
||||
self, tools: list[OpenAIResponseInputTool]
|
||||
) -> tuple[
|
||||
list[ChatCompletionToolParam],
|
||||
dict[str, OpenAIResponseInputToolMCP],
|
||||
OpenAIResponseOutput | None,
|
||||
]:
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
MCPListToolsTool,
|
||||
)
|
||||
from llama_stack.apis.tools import Tool
|
||||
|
||||
mcp_tool_to_server = {}
|
||||
|
||||
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
},
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(tool_def)
|
||||
|
||||
mcp_list_message = None
|
||||
chat_tools: list[ChatCompletionToolParam] = []
|
||||
for input_tool in tools:
|
||||
# TODO: Handle other tool types
|
||||
if input_tool.type == "function":
|
||||
chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
|
||||
elif input_tool.type in WebSearchToolTypes:
|
||||
tool_name = "web_search"
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
chat_tools.append(make_openai_tool(tool_name, tool))
|
||||
elif input_tool.type == "file_search":
|
||||
tool_name = "knowledge_search"
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
chat_tools.append(make_openai_tool(tool_name, tool))
|
||||
elif input_tool.type == "mcp":
|
||||
from llama_stack.providers.utils.tools.mcp import list_mcp_tools
|
||||
|
||||
always_allowed = None
|
||||
never_allowed = None
|
||||
if input_tool.allowed_tools:
|
||||
if isinstance(input_tool.allowed_tools, list):
|
||||
always_allowed = input_tool.allowed_tools
|
||||
elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
|
||||
always_allowed = input_tool.allowed_tools.always
|
||||
never_allowed = input_tool.allowed_tools.never
|
||||
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=input_tool.server_url,
|
||||
headers=input_tool.headers or {},
|
||||
)
|
||||
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
status="completed",
|
||||
server_label=input_tool.server_label,
|
||||
tools=[],
|
||||
)
|
||||
for t in tool_defs.data:
|
||||
if never_allowed and t.name in never_allowed:
|
||||
continue
|
||||
if not always_allowed or t.name in always_allowed:
|
||||
chat_tools.append(make_openai_tool(t.name, t))
|
||||
if t.name in mcp_tool_to_server:
|
||||
raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
|
||||
mcp_tool_to_server[t.name] = input_tool
|
||||
mcp_list_message.tools.append(
|
||||
MCPListToolsTool(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
p.name: {
|
||||
"type": p.parameter_type,
|
||||
"description": p.description,
|
||||
}
|
||||
for p in t.parameters
|
||||
},
|
||||
"required": [p.name for p in t.parameters if p.required],
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
|
||||
return chat_tools, mcp_tool_to_server, mcp_list_message
|
||||
|
||||
async def _execute_knowledge_search_via_vector_store(
|
||||
self,
|
||||
query: str,
|
||||
response_file_search_tool: OpenAIResponseInputToolFileSearch,
|
||||
) -> ToolInvocationResult:
|
||||
"""Execute knowledge search using vector_stores.search API with filters support."""
|
||||
search_results = []
|
||||
|
||||
# Create search tasks for all vector stores
|
||||
async def search_single_store(vector_store_id):
|
||||
try:
|
||||
search_response = await self.vector_io_api.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=response_file_search_tool.filters,
|
||||
max_num_results=response_file_search_tool.max_num_results,
|
||||
ranking_options=response_file_search_tool.ranking_options,
|
||||
rewrite_query=False,
|
||||
)
|
||||
return search_response.data
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
|
||||
return []
|
||||
|
||||
# Run all searches in parallel using gather
|
||||
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
|
||||
all_results = await asyncio.gather(*search_tasks)
|
||||
|
||||
# Flatten results
|
||||
for results in all_results:
|
||||
search_results.extend(results)
|
||||
|
||||
# Convert search results to tool result format matching memory.py
|
||||
# Format the results as interleaved content similar to memory.py
|
||||
content_items = []
|
||||
content_items.append(
|
||||
TextContentItem(
|
||||
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
|
||||
)
|
||||
)
|
||||
|
||||
for i, result_item in enumerate(search_results):
|
||||
chunk_text = result_item.content[0].text if result_item.content else ""
|
||||
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
|
||||
if result_item.attributes:
|
||||
metadata_text += f", attributes: {result_item.attributes}"
|
||||
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
|
||||
content_items.append(TextContentItem(text=text_content))
|
||||
|
||||
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||
content_items.append(
|
||||
TextContentItem(
|
||||
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
|
||||
)
|
||||
)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=content_items,
|
||||
metadata={
|
||||
"document_ids": [r.file_id for r in search_results],
|
||||
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
||||
"scores": [r.score for r in search_results],
|
||||
},
|
||||
)
|
||||
|
||||
async def _execute_tool_call(
|
||||
self,
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
ctx: ChatCompletionContext,
|
||||
) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
tool_call_id = tool_call.id
|
||||
function = tool_call.function
|
||||
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
|
||||
|
||||
if not function or not tool_call_id or not function.name:
|
||||
return None, None
|
||||
|
||||
error_exc = None
|
||||
result = None
|
||||
try:
|
||||
if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
|
||||
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
|
||||
|
||||
mcp_tool = ctx.mcp_tool_to_server[function.name]
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
tool_name=function.name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
elif function.name == "knowledge_search":
|
||||
response_file_search_tool = next(
|
||||
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None
|
||||
)
|
||||
if response_file_search_tool:
|
||||
# Use vector_stores.search API instead of knowledge_search tool
|
||||
# to support filters and ranking_options
|
||||
query = tool_kwargs.get("query", "")
|
||||
result = await self._execute_knowledge_search_via_vector_store(
|
||||
query=query,
|
||||
response_file_search_tool=response_file_search_tool,
|
||||
)
|
||||
else:
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function.name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
error_exc = e
|
||||
|
||||
if function.name in ctx.mcp_tool_to_server:
|
||||
from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
|
||||
|
||||
message = OpenAIResponseOutputMessageMCPCall(
|
||||
id=tool_call_id,
|
||||
arguments=function.arguments,
|
||||
name=function.name,
|
||||
server_label=ctx.mcp_tool_to_server[function.name].server_label,
|
||||
)
|
||||
if error_exc:
|
||||
message.error = str(error_exc)
|
||||
elif (result.error_code and result.error_code > 0) or result.error_message:
|
||||
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
||||
elif result.content:
|
||||
message.output = interleaved_content_as_str(result.content)
|
||||
else:
|
||||
if function.name == "web_search":
|
||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
status="completed",
|
||||
)
|
||||
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
|
||||
message.status = "failed"
|
||||
elif function.name == "knowledge_search":
|
||||
message = OpenAIResponseOutputMessageFileSearchToolCall(
|
||||
id=tool_call_id,
|
||||
queries=[tool_kwargs.get("query", "")],
|
||||
status="completed",
|
||||
)
|
||||
if "document_ids" in result.metadata:
|
||||
message.results = []
|
||||
for i, doc_id in enumerate(result.metadata["document_ids"]):
|
||||
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
|
||||
score = result.metadata["scores"][i] if "scores" in result.metadata else None
|
||||
message.results.append(
|
||||
{
|
||||
"file_id": doc_id,
|
||||
"filename": doc_id,
|
||||
"text": text,
|
||||
"score": score,
|
||||
}
|
||||
)
|
||||
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
|
||||
message.status = "failed"
|
||||
else:
|
||||
raise ValueError(f"Unknown tool {function.name} called")
|
||||
|
||||
input_message = None
|
||||
if result and result.content:
|
||||
if isinstance(result.content, str):
|
||||
content = result.content
|
||||
elif isinstance(result.content, list):
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
|
||||
content = []
|
||||
for item in result.content:
|
||||
if isinstance(item, TextContentItem):
|
||||
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
||||
elif isinstance(item, ImageContentItem):
|
||||
if item.image.data:
|
||||
url = f"data:image;base64,{item.image.data}"
|
||||
else:
|
||||
url = item.image.url
|
||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(item)}")
|
||||
content.append(part)
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
||||
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
||||
else:
|
||||
text = str(error_exc)
|
||||
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
||||
|
||||
return message, input_message
|
||||
|
||||
|
||||
def _is_function_tool_call(
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
tools: list[OpenAIResponseInputTool],
|
||||
) -> bool:
|
||||
if not tool_call.function:
|
||||
return False
|
||||
for t in tools:
|
||||
if t.type == "function" and t.name == tool_call.function.name:
|
||||
return True
|
||||
return False
|
|
@ -191,7 +191,11 @@ class AgentPersistence:
|
|||
sessions = []
|
||||
for value in values:
|
||||
try:
|
||||
session_info = Session(**json.loads(value))
|
||||
data = json.loads(value)
|
||||
if "turn_id" in data:
|
||||
continue
|
||||
|
||||
session_info = Session(**data)
|
||||
sessions.append(session_info)
|
||||
except Exception as e:
|
||||
log.error(f"Error parsing session info: {e}")
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -0,0 +1,271 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIDeleteResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAISystemMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
||||
from .streaming import StreamingResponseOrchestrator
|
||||
from .tool_executor import ToolExecutor
|
||||
from .types import ChatCompletionContext
|
||||
from .utils import (
|
||||
convert_response_input_to_chat_messages,
|
||||
convert_response_text_to_chat_response_format,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
|
||||
|
||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||
input_items: ListOpenAIResponseInputItem
|
||||
response: OpenAIResponseObject
|
||||
|
||||
|
||||
class OpenAIResponsesImpl:
|
||||
def __init__(
|
||||
self,
|
||||
inference_api: Inference,
|
||||
tool_groups_api: ToolGroups,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
responses_store: ResponsesStore,
|
||||
vector_io_api: VectorIO, # VectorIO
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.responses_store = responses_store
|
||||
self.vector_io_api = vector_io_api
|
||||
self.tool_executor = ToolExecutor(
|
||||
tool_groups_api=tool_groups_api,
|
||||
tool_runtime_api=tool_runtime_api,
|
||||
vector_io_api=vector_io_api,
|
||||
)
|
||||
|
||||
async def _prepend_previous_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
previous_response_id: str | None = None,
|
||||
):
|
||||
if previous_response_id:
|
||||
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
|
||||
|
||||
# previous response input items
|
||||
new_input_items = previous_response_with_input.input
|
||||
|
||||
# previous response output items
|
||||
new_input_items.extend(previous_response_with_input.output)
|
||||
|
||||
# new input items from the current request
|
||||
if isinstance(input, str):
|
||||
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
|
||||
else:
|
||||
new_input_items.extend(input)
|
||||
|
||||
input = new_input_items
|
||||
|
||||
return input
|
||||
|
||||
async def _prepend_instructions(self, messages, instructions):
|
||||
if instructions:
|
||||
messages.insert(0, OpenAISystemMessageParam(content=instructions))
|
||||
|
||||
async def get_openai_response(
|
||||
self,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
response_with_input = await self.responses_store.get_response_object(response_id)
|
||||
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
return await self.responses_store.list_responses(after, limit, model, order)
|
||||
|
||||
async def list_openai_response_input_items(
|
||||
self,
|
||||
response_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
include: list[str] | None = None,
|
||||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
"""List input items for a given OpenAI response.
|
||||
|
||||
:param response_id: The ID of the response to retrieve input items for.
|
||||
:param after: An item ID to list items after, used for pagination.
|
||||
:param before: An item ID to list items before, used for pagination.
|
||||
:param include: Additional fields to include in the response.
|
||||
:param limit: A limit on the number of objects to be returned.
|
||||
:param order: The order to return the input items in.
|
||||
:returns: An ListOpenAIResponseInputItem.
|
||||
"""
|
||||
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
|
||||
|
||||
async def _store_response(
|
||||
self,
|
||||
response: OpenAIResponseObject,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
) -> None:
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
if isinstance(input, str):
|
||||
# synthesize a message from the input string
|
||||
input_content = OpenAIResponseInputMessageContentText(text=input)
|
||||
input_content_item = OpenAIResponseMessage(
|
||||
role="user",
|
||||
content=[input_content],
|
||||
id=new_input_id,
|
||||
)
|
||||
input_items_data = [input_content_item]
|
||||
else:
|
||||
# we already have a list of messages
|
||||
input_items_data = []
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseMessage):
|
||||
# These may or may not already have an id, so dump to dict, check for id, and add if missing
|
||||
input_item_dict = input_item.model_dump()
|
||||
if "id" not in input_item_dict:
|
||||
input_item_dict["id"] = new_input_id
|
||||
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
|
||||
else:
|
||||
input_items_data.append(input_item)
|
||||
|
||||
await self.responses_store.store_response_object(
|
||||
response_object=response,
|
||||
input=input_items_data,
|
||||
)
|
||||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
):
|
||||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
||||
stream_gen = self._create_streaming_response(
|
||||
input=input,
|
||||
model=model,
|
||||
instructions=instructions,
|
||||
previous_response_id=previous_response_id,
|
||||
store=store,
|
||||
temperature=temperature,
|
||||
text=text,
|
||||
tools=tools,
|
||||
max_infer_iters=max_infer_iters,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return stream_gen
|
||||
else:
|
||||
response = None
|
||||
async for stream_chunk in stream_gen:
|
||||
if stream_chunk.type == "response.completed":
|
||||
if response is not None:
|
||||
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
|
||||
response = stream_chunk.response
|
||||
# don't leave the generator half complete!
|
||||
|
||||
if response is None:
|
||||
raise ValueError("The response stream never completed")
|
||||
return response
|
||||
|
||||
async def _create_streaming_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Input preprocessing
|
||||
input = await self._prepend_previous_response(input, previous_response_id)
|
||||
messages = await convert_response_input_to_chat_messages(input)
|
||||
await self._prepend_instructions(messages, instructions)
|
||||
|
||||
# Structured outputs
|
||||
response_format = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
ctx = ChatCompletionContext(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_tools=tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
# Create orchestrator and delegate streaming logic
|
||||
response_id = f"resp-{uuid.uuid4()}"
|
||||
created_at = int(time.time())
|
||||
|
||||
orchestrator = StreamingResponseOrchestrator(
|
||||
inference_api=self.inference_api,
|
||||
ctx=ctx,
|
||||
response_id=response_id,
|
||||
created_at=created_at,
|
||||
text=text,
|
||||
max_infer_iters=max_infer_iters,
|
||||
tool_executor=self.tool_executor,
|
||||
)
|
||||
|
||||
# Stream the response
|
||||
final_response = None
|
||||
async for stream_chunk in orchestrator.create_response():
|
||||
if stream_chunk.type == "response.completed":
|
||||
final_response = stream_chunk.response
|
||||
yield stream_chunk
|
||||
|
||||
# Store the response if requested
|
||||
if store and final_response:
|
||||
await self._store_response(
|
||||
response=final_response,
|
||||
input=input,
|
||||
)
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
return await self.responses_store.delete_response_object(response_id)
|
|
@ -0,0 +1,634 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
AllowedToolsFilter,
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseContentPartOutputText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
OpenAIResponseObjectStreamResponseContentPartAdded,
|
||||
OpenAIResponseObjectStreamResponseContentPartDone,
|
||||
OpenAIResponseObjectStreamResponseCreated,
|
||||
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
|
||||
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
|
||||
OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
|
||||
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
|
||||
OpenAIResponseObjectStreamResponseMcpListToolsCompleted,
|
||||
OpenAIResponseObjectStreamResponseMcpListToolsInProgress,
|
||||
OpenAIResponseObjectStreamResponseOutputItemAdded,
|
||||
OpenAIResponseObjectStreamResponseOutputItemDone,
|
||||
OpenAIResponseObjectStreamResponseOutputTextDelta,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseText,
|
||||
WebSearchToolTypes,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChoice,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .types import ChatCompletionContext, ChatCompletionResult
|
||||
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
|
||||
|
||||
class StreamingResponseOrchestrator:
|
||||
def __init__(
|
||||
self,
|
||||
inference_api: Inference,
|
||||
ctx: ChatCompletionContext,
|
||||
response_id: str,
|
||||
created_at: int,
|
||||
text: OpenAIResponseText,
|
||||
max_infer_iters: int,
|
||||
tool_executor, # Will be the tool execution logic from the main class
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.ctx = ctx
|
||||
self.response_id = response_id
|
||||
self.created_at = created_at
|
||||
self.text = text
|
||||
self.max_infer_iters = max_infer_iters
|
||||
self.tool_executor = tool_executor
|
||||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
|
||||
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Initialize output messages
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
# Create initial response and emit response.created immediately
|
||||
initial_response = OpenAIResponseObject(
|
||||
created_at=self.created_at,
|
||||
id=self.response_id,
|
||||
model=self.ctx.model,
|
||||
object="response",
|
||||
status="in_progress",
|
||||
output=output_messages.copy(),
|
||||
text=self.text,
|
||||
)
|
||||
|
||||
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
|
||||
|
||||
# Process all tools (including MCP tools) and emit streaming events
|
||||
if self.ctx.response_tools:
|
||||
async for stream_event in self._process_tools(self.ctx.response_tools, output_messages):
|
||||
yield stream_event
|
||||
|
||||
n_iter = 0
|
||||
messages = self.ctx.messages.copy()
|
||||
|
||||
while True:
|
||||
completion_result = await self.inference_api.openai_chat_completion(
|
||||
model=self.ctx.model,
|
||||
messages=messages,
|
||||
tools=self.ctx.chat_tools,
|
||||
stream=True,
|
||||
temperature=self.ctx.temperature,
|
||||
response_format=self.ctx.response_format,
|
||||
)
|
||||
|
||||
# Process streaming chunks and build complete response
|
||||
completion_result_data = None
|
||||
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
|
||||
if isinstance(stream_event_or_result, ChatCompletionResult):
|
||||
completion_result_data = stream_event_or_result
|
||||
else:
|
||||
yield stream_event_or_result
|
||||
if not completion_result_data:
|
||||
raise ValueError("Streaming chunk processor failed to return completion data")
|
||||
current_response = self._build_chat_completion(completion_result_data)
|
||||
|
||||
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
|
||||
current_response, messages
|
||||
)
|
||||
|
||||
# Handle choices with no tool calls
|
||||
for choice in current_response.choices:
|
||||
if not (choice.message.tool_calls and self.ctx.response_tools):
|
||||
output_messages.append(await convert_chat_choice_to_response_message(choice))
|
||||
|
||||
# Execute tool calls and coordinate results
|
||||
async for stream_event in self._coordinate_tool_execution(
|
||||
function_tool_calls,
|
||||
non_function_tool_calls,
|
||||
completion_result_data,
|
||||
output_messages,
|
||||
next_turn_messages,
|
||||
):
|
||||
yield stream_event
|
||||
|
||||
if not function_tool_calls and not non_function_tool_calls:
|
||||
break
|
||||
|
||||
if function_tool_calls:
|
||||
logger.info("Exiting inference loop since there is a function (client-side) tool call")
|
||||
break
|
||||
|
||||
n_iter += 1
|
||||
if n_iter >= self.max_infer_iters:
|
||||
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
|
||||
break
|
||||
|
||||
messages = next_turn_messages
|
||||
|
||||
# Create final response
|
||||
final_response = OpenAIResponseObject(
|
||||
created_at=self.created_at,
|
||||
id=self.response_id,
|
||||
model=self.ctx.model,
|
||||
object="response",
|
||||
status="completed",
|
||||
text=self.text,
|
||||
output=output_messages,
|
||||
)
|
||||
|
||||
# Emit response.completed
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||
|
||||
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
|
||||
"""Separate tool calls into function and non-function categories."""
|
||||
function_tool_calls = []
|
||||
non_function_tool_calls = []
|
||||
next_turn_messages = messages.copy()
|
||||
|
||||
for choice in current_response.choices:
|
||||
next_turn_messages.append(choice.message)
|
||||
|
||||
if choice.message.tool_calls and self.ctx.response_tools:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if is_function_tool_call(tool_call, self.ctx.response_tools):
|
||||
function_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
|
||||
return function_tool_calls, non_function_tool_calls, next_turn_messages
|
||||
|
||||
async def _process_streaming_chunks(
|
||||
self, completion_result, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]:
|
||||
"""Process streaming chunks and emit events, returning completion data."""
|
||||
# Initialize result tracking
|
||||
chat_response_id = ""
|
||||
chat_response_content = []
|
||||
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
||||
chunk_created = 0
|
||||
chunk_model = ""
|
||||
chunk_finish_reason = ""
|
||||
|
||||
# Create a placeholder message item for delta events
|
||||
message_item_id = f"msg_{uuid.uuid4()}"
|
||||
# Track tool call items for streaming events
|
||||
tool_call_item_ids: dict[int, str] = {}
|
||||
# Track content parts for streaming events
|
||||
content_part_emitted = False
|
||||
|
||||
async for chunk in completion_result:
|
||||
chat_response_id = chunk.id
|
||||
chunk_created = chunk.created
|
||||
chunk_model = chunk.model
|
||||
for chunk_choice in chunk.choices:
|
||||
# Emit incremental text content as delta events
|
||||
if chunk_choice.delta.content:
|
||||
# Emit content_part.added event for first text chunk
|
||||
if not content_part_emitted:
|
||||
content_part_emitted = True
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseContentPartAdded(
|
||||
response_id=self.response_id,
|
||||
item_id=message_item_id,
|
||||
part=OpenAIResponseContentPartOutputText(
|
||||
text="", # Will be filled incrementally via text deltas
|
||||
),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||
content_index=0,
|
||||
delta=chunk_choice.delta.content,
|
||||
item_id=message_item_id,
|
||||
output_index=0,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Collect content for final response
|
||||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
if chunk_choice.finish_reason:
|
||||
chunk_finish_reason = chunk_choice.finish_reason
|
||||
|
||||
# Aggregate tool call arguments across chunks
|
||||
if chunk_choice.delta.tool_calls:
|
||||
for tool_call in chunk_choice.delta.tool_calls:
|
||||
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||
# Create new tool call entry if this is the first chunk for this index
|
||||
is_new_tool_call = response_tool_call is None
|
||||
if is_new_tool_call:
|
||||
tool_call_dict: dict[str, Any] = tool_call.model_dump()
|
||||
tool_call_dict.pop("type", None)
|
||||
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
|
||||
chat_response_tool_calls[tool_call.index] = response_tool_call
|
||||
|
||||
# Create item ID for this tool call for streaming events
|
||||
tool_call_item_id = f"fc_{uuid.uuid4()}"
|
||||
tool_call_item_ids[tool_call.index] = tool_call_item_id
|
||||
|
||||
# Emit output_item.added event for the new function call
|
||||
self.sequence_number += 1
|
||||
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
|
||||
arguments="", # Will be filled incrementally via delta events
|
||||
call_id=tool_call.id or "",
|
||||
name=tool_call.function.name if tool_call.function else "",
|
||||
id=tool_call_item_id,
|
||||
status="in_progress",
|
||||
)
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=function_call_item,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Stream tool call arguments as they arrive (differentiate between MCP and function calls)
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_call_item_id = tool_call_item_ids[tool_call.index]
|
||||
self.sequence_number += 1
|
||||
|
||||
# Check if this is an MCP tool call
|
||||
is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
|
||||
if is_mcp_tool:
|
||||
# Emit MCP-specific argument delta event
|
||||
yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(
|
||||
delta=tool_call.function.arguments,
|
||||
item_id=tool_call_item_id,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
else:
|
||||
# Emit function call argument delta event
|
||||
yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
|
||||
delta=tool_call.function.arguments,
|
||||
item_id=tool_call_item_id,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Accumulate arguments for final response (only for subsequent chunks)
|
||||
if not is_new_tool_call:
|
||||
response_tool_call.function.arguments = (
|
||||
response_tool_call.function.arguments or ""
|
||||
) + tool_call.function.arguments
|
||||
|
||||
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
|
||||
for tool_call_index in sorted(chat_response_tool_calls.keys()):
|
||||
tool_call_item_id = tool_call_item_ids[tool_call_index]
|
||||
final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
|
||||
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
|
||||
|
||||
# Check if this is an MCP tool call
|
||||
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
|
||||
self.sequence_number += 1
|
||||
done_event_cls = (
|
||||
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone
|
||||
if is_mcp_tool
|
||||
else OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone
|
||||
)
|
||||
yield done_event_cls(
|
||||
arguments=final_arguments,
|
||||
item_id=tool_call_item_id,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit content_part.done event if text content was streamed (before content gets cleared)
|
||||
if content_part_emitted:
|
||||
final_text = "".join(chat_response_content)
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseContentPartDone(
|
||||
response_id=self.response_id,
|
||||
item_id=message_item_id,
|
||||
part=OpenAIResponseContentPartOutputText(
|
||||
text=final_text,
|
||||
),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Clear content when there are tool calls (OpenAI spec behavior)
|
||||
if chat_response_tool_calls:
|
||||
chat_response_content = []
|
||||
|
||||
yield ChatCompletionResult(
|
||||
response_id=chat_response_id,
|
||||
content=chat_response_content,
|
||||
tool_calls=chat_response_tool_calls,
|
||||
created=chunk_created,
|
||||
model=chunk_model,
|
||||
finish_reason=chunk_finish_reason,
|
||||
message_item_id=message_item_id,
|
||||
tool_call_item_ids=tool_call_item_ids,
|
||||
content_part_emitted=content_part_emitted,
|
||||
)
|
||||
|
||||
def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion:
|
||||
"""Build OpenAIChatCompletion from ChatCompletionResult."""
|
||||
# Convert collected chunks to complete response
|
||||
if result.tool_calls:
|
||||
tool_calls = [result.tool_calls[i] for i in sorted(result.tool_calls.keys())]
|
||||
else:
|
||||
tool_calls = None
|
||||
|
||||
assistant_message = OpenAIAssistantMessageParam(
|
||||
content=result.content_text,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
return OpenAIChatCompletion(
|
||||
id=result.response_id,
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=assistant_message,
|
||||
finish_reason=result.finish_reason,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=result.created,
|
||||
model=result.model,
|
||||
)
|
||||
|
||||
async def _coordinate_tool_execution(
|
||||
self,
|
||||
function_tool_calls: list,
|
||||
non_function_tool_calls: list,
|
||||
completion_result_data: ChatCompletionResult,
|
||||
output_messages: list[OpenAIResponseOutput],
|
||||
next_turn_messages: list,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Coordinate execution of both function and non-function tool calls."""
|
||||
# Execute non-function tool calls
|
||||
for tool_call in non_function_tool_calls:
|
||||
# Find the item_id for this tool call
|
||||
matching_item_id = None
|
||||
for index, item_id in completion_result_data.tool_call_item_ids.items():
|
||||
response_tool_call = completion_result_data.tool_calls.get(index)
|
||||
if response_tool_call and response_tool_call.id == tool_call.id:
|
||||
matching_item_id = item_id
|
||||
break
|
||||
|
||||
# Use a fallback item_id if not found
|
||||
if not matching_item_id:
|
||||
matching_item_id = f"tc_{uuid.uuid4()}"
|
||||
|
||||
# Execute tool call with streaming
|
||||
tool_call_log = None
|
||||
tool_response_message = None
|
||||
async for result in self.tool_executor.execute_tool_call(
|
||||
tool_call,
|
||||
self.ctx,
|
||||
self.sequence_number,
|
||||
len(output_messages),
|
||||
matching_item_id,
|
||||
self.mcp_tool_to_server,
|
||||
):
|
||||
if result.stream_event:
|
||||
# Forward streaming events
|
||||
self.sequence_number = result.sequence_number
|
||||
yield result.stream_event
|
||||
|
||||
if result.final_output_message is not None:
|
||||
tool_call_log = result.final_output_message
|
||||
tool_response_message = result.final_input_message
|
||||
self.sequence_number = result.sequence_number
|
||||
|
||||
if tool_call_log:
|
||||
output_messages.append(tool_call_log)
|
||||
|
||||
# Emit output_item.done event for completed non-function tool call
|
||||
if matching_item_id:
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=tool_call_log,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
if tool_response_message:
|
||||
next_turn_messages.append(tool_response_message)
|
||||
|
||||
# Execute function tool calls (client-side)
|
||||
for tool_call in function_tool_calls:
|
||||
# Find the item_id for this tool call from our tracking dictionary
|
||||
matching_item_id = None
|
||||
for index, item_id in completion_result_data.tool_call_item_ids.items():
|
||||
response_tool_call = completion_result_data.tool_calls.get(index)
|
||||
if response_tool_call and response_tool_call.id == tool_call.id:
|
||||
matching_item_id = item_id
|
||||
break
|
||||
|
||||
# Use existing item_id or create new one if not found
|
||||
final_item_id = matching_item_id or f"fc_{uuid.uuid4()}"
|
||||
|
||||
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
|
||||
arguments=tool_call.function.arguments or "",
|
||||
call_id=tool_call.id,
|
||||
name=tool_call.function.name or "",
|
||||
id=final_item_id,
|
||||
status="completed",
|
||||
)
|
||||
output_messages.append(function_call_item)
|
||||
|
||||
# Emit output_item.done event for completed function call
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=function_call_item,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _process_tools(
|
||||
self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Process all tools and emit appropriate streaming events."""
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from llama_stack.apis.tools import Tool
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
},
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(tool_def)
|
||||
|
||||
# Initialize chat_tools if not already set
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
|
||||
for input_tool in tools:
|
||||
if input_tool.type == "function":
|
||||
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
|
||||
elif input_tool.type in WebSearchToolTypes:
|
||||
tool_name = "web_search"
|
||||
# Need to access tool_groups_api from tool_executor
|
||||
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
|
||||
elif input_tool.type == "file_search":
|
||||
tool_name = "knowledge_search"
|
||||
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
|
||||
elif input_tool.type == "mcp":
|
||||
async for stream_event in self._process_mcp_tool(input_tool, output_messages):
|
||||
yield stream_event
|
||||
else:
|
||||
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
|
||||
|
||||
async def _process_mcp_tool(
|
||||
self, mcp_tool: OpenAIResponseInputToolMCP, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Process an MCP tool configuration and emit appropriate streaming events."""
|
||||
from llama_stack.providers.utils.tools.mcp import list_mcp_tools
|
||||
|
||||
# Emit mcp_list_tools.in_progress
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse allowed/never allowed tools
|
||||
always_allowed = None
|
||||
never_allowed = None
|
||||
if mcp_tool.allowed_tools:
|
||||
if isinstance(mcp_tool.allowed_tools, list):
|
||||
always_allowed = mcp_tool.allowed_tools
|
||||
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
|
||||
always_allowed = mcp_tool.allowed_tools.always
|
||||
never_allowed = mcp_tool.allowed_tools.never
|
||||
|
||||
# Call list_mcp_tools
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
)
|
||||
|
||||
# Create the MCP list tools message
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
server_label=mcp_tool.server_label,
|
||||
tools=[],
|
||||
)
|
||||
|
||||
# Process tools and update context
|
||||
for t in tool_defs.data:
|
||||
if never_allowed and t.name in never_allowed:
|
||||
continue
|
||||
if not always_allowed or t.name in always_allowed:
|
||||
# Add to chat tools for inference
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=t.name,
|
||||
description=t.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in t.parameters
|
||||
},
|
||||
)
|
||||
openai_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
self.ctx.chat_tools.append(openai_tool)
|
||||
|
||||
# Add to MCP tool mapping
|
||||
if t.name in self.mcp_tool_to_server:
|
||||
raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}")
|
||||
self.mcp_tool_to_server[t.name] = mcp_tool
|
||||
|
||||
# Add to MCP list message
|
||||
mcp_list_message.tools.append(
|
||||
MCPListToolsTool(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
p.name: {
|
||||
"type": p.parameter_type,
|
||||
"description": p.description,
|
||||
}
|
||||
for p in t.parameters
|
||||
},
|
||||
"required": [p.name for p in t.parameters if p.required],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add the MCP list message to output
|
||||
output_messages.append(mcp_list_message)
|
||||
|
||||
# Emit output_item.added for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit mcp_list_tools.completed
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit output_item.done for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Emit mcp_list_tools.failed event if needed
|
||||
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
|
||||
raise
|
|
@ -0,0 +1,379 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseObjectStreamResponseMcpCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseMcpCallFailed,
|
||||
OpenAIResponseObjectStreamResponseMcpCallInProgress,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
|
||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIImageURL,
|
||||
OpenAIToolMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .types import ChatCompletionContext, ToolExecutionResult
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
tool_groups_api: ToolGroups,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
vector_io_api: VectorIO,
|
||||
):
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.vector_io_api = vector_io_api
|
||||
|
||||
async def execute_tool_call(
|
||||
self,
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
ctx: ChatCompletionContext,
|
||||
sequence_number: int,
|
||||
output_index: int,
|
||||
item_id: str,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
tool_call_id = tool_call.id
|
||||
function = tool_call.function
|
||||
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
|
||||
|
||||
if not function or not tool_call_id or not function.name:
|
||||
yield ToolExecutionResult(sequence_number=sequence_number)
|
||||
return
|
||||
|
||||
# Emit progress events for tool execution start
|
||||
async for event_result in self._emit_progress_events(
|
||||
function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server
|
||||
):
|
||||
sequence_number = event_result.sequence_number
|
||||
yield event_result
|
||||
|
||||
# Execute the actual tool call
|
||||
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
|
||||
|
||||
# Emit completion events for tool execution
|
||||
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
|
||||
async for event_result in self._emit_completion_events(
|
||||
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
|
||||
):
|
||||
sequence_number = event_result.sequence_number
|
||||
yield event_result
|
||||
|
||||
# Build result messages from tool execution
|
||||
output_message, input_message = await self._build_result_messages(
|
||||
function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
|
||||
)
|
||||
|
||||
# Yield the final result
|
||||
yield ToolExecutionResult(
|
||||
sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
|
||||
)
|
||||
|
||||
async def _execute_knowledge_search_via_vector_store(
|
||||
self,
|
||||
query: str,
|
||||
response_file_search_tool: OpenAIResponseInputToolFileSearch,
|
||||
) -> ToolInvocationResult:
|
||||
"""Execute knowledge search using vector_stores.search API with filters support."""
|
||||
search_results = []
|
||||
|
||||
# Create search tasks for all vector stores
|
||||
async def search_single_store(vector_store_id):
|
||||
try:
|
||||
search_response = await self.vector_io_api.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=response_file_search_tool.filters,
|
||||
max_num_results=response_file_search_tool.max_num_results,
|
||||
ranking_options=response_file_search_tool.ranking_options,
|
||||
rewrite_query=False,
|
||||
)
|
||||
return search_response.data
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
|
||||
return []
|
||||
|
||||
# Run all searches in parallel using gather
|
||||
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
|
||||
all_results = await asyncio.gather(*search_tasks)
|
||||
|
||||
# Flatten results
|
||||
for results in all_results:
|
||||
search_results.extend(results)
|
||||
|
||||
# Convert search results to tool result format matching memory.py
|
||||
# Format the results as interleaved content similar to memory.py
|
||||
content_items = []
|
||||
content_items.append(
|
||||
TextContentItem(
|
||||
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
|
||||
)
|
||||
)
|
||||
|
||||
for i, result_item in enumerate(search_results):
|
||||
chunk_text = result_item.content[0].text if result_item.content else ""
|
||||
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
|
||||
if result_item.attributes:
|
||||
metadata_text += f", attributes: {result_item.attributes}"
|
||||
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
|
||||
content_items.append(TextContentItem(text=text_content))
|
||||
|
||||
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||
content_items.append(
|
||||
TextContentItem(
|
||||
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
|
||||
)
|
||||
)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=content_items,
|
||||
metadata={
|
||||
"document_ids": [r.file_id for r in search_results],
|
||||
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
||||
"scores": [r.score for r in search_results],
|
||||
},
|
||||
)
|
||||
|
||||
async def _emit_progress_events(
|
||||
self,
|
||||
function_name: str,
|
||||
ctx: ChatCompletionContext,
|
||||
sequence_number: int,
|
||||
output_index: int,
|
||||
item_id: str,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit progress events for tool execution start."""
|
||||
# Emit in_progress event based on tool type (only for tools with specific streaming events)
|
||||
progress_event = None
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "web_search":
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
|
||||
|
||||
if progress_event:
|
||||
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
|
||||
|
||||
# For web search, emit searching event
|
||||
if function_name == "web_search":
|
||||
sequence_number += 1
|
||||
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
async def _execute_tool(
|
||||
self,
|
||||
function_name: str,
|
||||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> tuple[Exception | None, any]:
|
||||
"""Execute the tool and return error exception and result."""
|
||||
error_exc = None
|
||||
result = None
|
||||
|
||||
try:
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
|
||||
|
||||
mcp_tool = mcp_tool_to_server[function_name]
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
response_file_search_tool = next(
|
||||
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
||||
None,
|
||||
)
|
||||
if response_file_search_tool:
|
||||
# Use vector_stores.search API instead of knowledge_search tool
|
||||
# to support filters and ranking_options
|
||||
query = tool_kwargs.get("query", "")
|
||||
result = await self._execute_knowledge_search_via_vector_store(
|
||||
query=query,
|
||||
response_file_search_tool=response_file_search_tool,
|
||||
)
|
||||
else:
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
error_exc = e
|
||||
|
||||
return error_exc, result
|
||||
|
||||
async def _emit_completion_events(
|
||||
self,
|
||||
function_name: str,
|
||||
ctx: ChatCompletionContext,
|
||||
sequence_number: int,
|
||||
output_index: int,
|
||||
item_id: str,
|
||||
has_error: bool,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit completion or failure events for tool execution."""
|
||||
completion_event = None
|
||||
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
if has_error:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
else:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "web_search":
|
||||
sequence_number += 1
|
||||
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
|
||||
|
||||
if completion_event:
|
||||
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
|
||||
|
||||
async def _build_result_messages(
|
||||
self,
|
||||
function,
|
||||
tool_call_id: str,
|
||||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
error_exc: Exception | None,
|
||||
result: any,
|
||||
has_error: bool,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> tuple[any, any]:
|
||||
"""Build output and input messages from tool execution results."""
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
# Build output message
|
||||
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
)
|
||||
|
||||
message = OpenAIResponseOutputMessageMCPCall(
|
||||
id=tool_call_id,
|
||||
arguments=function.arguments,
|
||||
name=function.name,
|
||||
server_label=mcp_tool_to_server[function.name].server_label,
|
||||
)
|
||||
if error_exc:
|
||||
message.error = str(error_exc)
|
||||
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
|
||||
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
||||
elif result and result.content:
|
||||
message.output = interleaved_content_as_str(result.content)
|
||||
else:
|
||||
if function.name == "web_search":
|
||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
status="completed",
|
||||
)
|
||||
if has_error:
|
||||
message.status = "failed"
|
||||
elif function.name == "knowledge_search":
|
||||
message = OpenAIResponseOutputMessageFileSearchToolCall(
|
||||
id=tool_call_id,
|
||||
queries=[tool_kwargs.get("query", "")],
|
||||
status="completed",
|
||||
)
|
||||
if result and "document_ids" in result.metadata:
|
||||
message.results = []
|
||||
for i, doc_id in enumerate(result.metadata["document_ids"]):
|
||||
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
|
||||
score = result.metadata["scores"][i] if "scores" in result.metadata else None
|
||||
message.results.append(
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults(
|
||||
file_id=doc_id,
|
||||
filename=doc_id,
|
||||
text=text,
|
||||
score=score,
|
||||
attributes={},
|
||||
)
|
||||
)
|
||||
if has_error:
|
||||
message.status = "failed"
|
||||
else:
|
||||
raise ValueError(f"Unknown tool {function.name} called")
|
||||
|
||||
# Build input message
|
||||
input_message = None
|
||||
if result and result.content:
|
||||
if isinstance(result.content, str):
|
||||
content = result.content
|
||||
elif isinstance(result.content, list):
|
||||
content = []
|
||||
for item in result.content:
|
||||
if isinstance(item, TextContentItem):
|
||||
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
||||
elif isinstance(item, ImageContentItem):
|
||||
if item.image.data:
|
||||
url = f"data:image;base64,{item.image.data}"
|
||||
else:
|
||||
url = item.image.url
|
||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(item)}")
|
||||
content.append(part)
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
||||
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
||||
else:
|
||||
text = str(error_exc) if error_exc else "Tool execution failed"
|
||||
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
||||
|
||||
return message, input_message
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
|
||||
|
||||
|
||||
class ToolExecutionResult(BaseModel):
|
||||
"""Result of streaming tool execution."""
|
||||
|
||||
stream_event: OpenAIResponseObjectStream | None = None
|
||||
sequence_number: int
|
||||
final_output_message: OpenAIResponseOutput | None = None
|
||||
final_input_message: OpenAIMessageParam | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionResult:
|
||||
"""Result of processing streaming chat completion chunks."""
|
||||
|
||||
response_id: str
|
||||
content: list[str]
|
||||
tool_calls: dict[int, OpenAIChatCompletionToolCall]
|
||||
created: int
|
||||
model: str
|
||||
finish_reason: str
|
||||
message_item_id: str # For streaming events
|
||||
tool_call_item_ids: dict[int, str] # For streaming events
|
||||
content_part_emitted: bool # Tracking state
|
||||
|
||||
@property
|
||||
def content_text(self) -> str:
|
||||
"""Get joined content as string."""
|
||||
return "".join(self.content)
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if there are any tool calls."""
|
||||
return bool(self.tool_calls)
|
||||
|
||||
|
||||
class ChatCompletionContext(BaseModel):
|
||||
model: str
|
||||
messages: list[OpenAIMessageParam]
|
||||
response_tools: list[OpenAIResponseInputTool] | None = None
|
||||
chat_tools: list[ChatCompletionToolParam] | None = None
|
||||
temperature: float | None
|
||||
response_format: OpenAIResponseFormatParam
|
|
@ -0,0 +1,169 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContent,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseText,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAIJSONSchema,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAIResponseFormatText,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
|
||||
|
||||
async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
|
||||
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
|
||||
output_content = ""
|
||||
if isinstance(choice.message.content, str):
|
||||
output_content = choice.message.content
|
||||
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
|
||||
output_content = choice.message.content.text
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
|
||||
)
|
||||
|
||||
return OpenAIResponseMessage(
|
||||
id=f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
||||
async def convert_response_content_to_chat_content(
|
||||
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
"""
|
||||
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
||||
|
||||
The content schemas of each API look similar, but are not exactly the same.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
converted_parts = []
|
||||
for content_part in content:
|
||||
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
|
||||
if content_part.image_url:
|
||||
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
|
||||
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||
elif isinstance(content_part, str):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
|
||||
)
|
||||
return converted_parts
|
||||
|
||||
|
||||
async def convert_response_input_to_chat_messages(
|
||||
input: str | list[OpenAIResponseInput],
|
||||
) -> list[OpenAIMessageParam]:
|
||||
"""
|
||||
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
|
||||
"""
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if isinstance(input, list):
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.call_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id=input_item.call_id,
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=input_item.name,
|
||||
arguments=input_item.arguments,
|
||||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
else:
|
||||
content = await convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await get_message_type_by_role(input_item.role)
|
||||
if message_type is None:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
messages.append(message_type(content=content))
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
||||
|
||||
async def convert_response_text_to_chat_response_format(
|
||||
text: OpenAIResponseText,
|
||||
) -> OpenAIResponseFormatParam:
|
||||
"""
|
||||
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
|
||||
"""
|
||||
if not text.format or text.format["type"] == "text":
|
||||
return OpenAIResponseFormatText(type="text")
|
||||
if text.format["type"] == "json_object":
|
||||
return OpenAIResponseFormatJSONObject()
|
||||
if text.format["type"] == "json_schema":
|
||||
return OpenAIResponseFormatJSONSchema(
|
||||
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
|
||||
)
|
||||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
||||
async def get_message_type_by_role(role: str):
|
||||
role_to_type = {
|
||||
"user": OpenAIUserMessageParam,
|
||||
"system": OpenAISystemMessageParam,
|
||||
"assistant": OpenAIAssistantMessageParam,
|
||||
"developer": OpenAIDeveloperMessageParam,
|
||||
}
|
||||
return role_to_type.get(role)
|
||||
|
||||
|
||||
def is_function_tool_call(
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
tools: list[OpenAIResponseInputTool],
|
||||
) -> bool:
|
||||
if not tool_call.function:
|
||||
return False
|
||||
for t in tools:
|
||||
if t.type == "function" and t.name == tool_call.function.name:
|
||||
return True
|
||||
return False
|
5
llama_stack/providers/inline/batches/__init__.py
Normal file
5
llama_stack/providers/inline/batches/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
36
llama_stack/providers/inline/batches/reference/__init__.py
Normal file
36
llama_stack/providers/inline/batches/reference/__init__.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.core.datatypes import AccessRule, Api
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .batches import ReferenceBatchesImpl
|
||||
from .config import ReferenceBatchesImplConfig
|
||||
|
||||
__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
inference_api: Inference | None = deps.get(Api.inference)
|
||||
files_api: Files | None = deps.get(Api.files)
|
||||
models_api: Models | None = deps.get(Api.models)
|
||||
|
||||
if inference_api is None:
|
||||
raise ValueError("Inference API is required but not provided in dependencies")
|
||||
if files_api is None:
|
||||
raise ValueError("Files API is required but not provided in dependencies")
|
||||
if models_api is None:
|
||||
raise ValueError("Models API is required but not provided in dependencies")
|
||||
|
||||
impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore)
|
||||
await impl.initialize()
|
||||
return impl
|
580
llama_stack/providers/inline/batches/reference/batches.py
Normal file
580
llama_stack/providers/inline/batches/reference/batches.py
Normal file
|
@ -0,0 +1,580 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Any, Literal
|
||||
|
||||
from openai.types.batch import BatchError, Errors
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.apis.files import Files, OpenAIFilePurpose
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
from .config import ReferenceBatchesImplConfig
|
||||
|
||||
BATCH_PREFIX = "batch:"
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AsyncBytesIO:
|
||||
"""
|
||||
Async-compatible BytesIO wrapper to allow async file-like operations.
|
||||
|
||||
We use this when uploading files to the Files API, as it expects an
|
||||
async file-like object.
|
||||
"""
|
||||
|
||||
def __init__(self, data: bytes):
|
||||
self._buffer = BytesIO(data)
|
||||
|
||||
async def read(self, n=-1):
|
||||
return self._buffer.read(n)
|
||||
|
||||
async def seek(self, pos, whence=0):
|
||||
return self._buffer.seek(pos, whence)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._buffer.close()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._buffer, name)
|
||||
|
||||
|
||||
class BatchRequest(BaseModel):
|
||||
line_num: int
|
||||
custom_id: str
|
||||
method: str
|
||||
url: str
|
||||
body: dict[str, Any]
|
||||
|
||||
|
||||
def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam:
|
||||
"""Convert a message dictionary to OpenAIMessageParam based on role."""
|
||||
role = msg.get("role")
|
||||
|
||||
if role == "user":
|
||||
return OpenAIUserMessageParam(**msg)
|
||||
elif role == "system":
|
||||
return OpenAISystemMessageParam(**msg)
|
||||
elif role == "assistant":
|
||||
return OpenAIAssistantMessageParam(**msg)
|
||||
elif role == "tool":
|
||||
return OpenAIToolMessageParam(**msg)
|
||||
elif role == "developer":
|
||||
return OpenAIDeveloperMessageParam(**msg)
|
||||
else:
|
||||
raise ValueError(f"Unknown message role: {role}")
|
||||
|
||||
|
||||
class ReferenceBatchesImpl(Batches):
|
||||
"""Reference implementation of the Batches API.
|
||||
|
||||
This implementation processes batch files by making individual requests
|
||||
to the inference API and generates output files with results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ReferenceBatchesImplConfig,
|
||||
inference_api: Inference,
|
||||
files_api: Files,
|
||||
models_api: Models,
|
||||
kvstore: KVStore,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.kvstore = kvstore
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
self.models_api = models_api
|
||||
self._processing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches)
|
||||
self._update_batch_lock = asyncio.Lock()
|
||||
|
||||
# this is to allow tests to disable background processing
|
||||
self.process_batches = True
|
||||
|
||||
async def initialize(self) -> None:
|
||||
# TODO: start background processing of existing tasks
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the batches provider."""
|
||||
if self._processing_tasks:
|
||||
# don't cancel tasks - just let them stop naturally on shutdown
|
||||
# cancelling would mark batches as "cancelled" in the database
|
||||
logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
|
||||
|
||||
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
|
||||
async def create_batch(
|
||||
self,
|
||||
input_file_id: str,
|
||||
endpoint: str,
|
||||
completion_window: Literal["24h"],
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> BatchObject:
|
||||
"""
|
||||
Create a new batch for processing multiple API requests.
|
||||
|
||||
Error handling by levels -
|
||||
0. Input param handling, results in 40x errors before processing, e.g.
|
||||
- Wrong completion_window
|
||||
- Invalid metadata types
|
||||
- Unknown endpoint
|
||||
-> no batch created
|
||||
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
|
||||
- input_file_id missing
|
||||
- invalid json in file
|
||||
- missing custom_id, method, url, body
|
||||
- invalid model
|
||||
- streaming
|
||||
-> batch created, validation sends to failed status
|
||||
2. Processing errors, result in error_file_id entries, e.g.
|
||||
- Any error returned from inference endpoint
|
||||
-> batch created, goes to completed status
|
||||
"""
|
||||
|
||||
# TODO: set expiration time for garbage collection
|
||||
|
||||
if endpoint not in ["/v1/chat/completions"]:
|
||||
raise ValueError(
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
|
||||
)
|
||||
|
||||
if completion_window != "24h":
|
||||
raise ValueError(
|
||||
f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
|
||||
)
|
||||
|
||||
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
|
||||
current_time = int(time.time())
|
||||
|
||||
batch = BatchObject(
|
||||
id=batch_id,
|
||||
object="batch",
|
||||
endpoint=endpoint,
|
||||
input_file_id=input_file_id,
|
||||
completion_window=completion_window,
|
||||
status="validating",
|
||||
created_at=current_time,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
|
||||
|
||||
if self.process_batches:
|
||||
task = asyncio.create_task(self._process_batch(batch_id))
|
||||
self._processing_tasks[batch_id] = task
|
||||
|
||||
return batch
|
||||
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Cancel a batch that is in progress."""
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
|
||||
if batch.status in ["cancelled", "cancelling"]:
|
||||
return batch
|
||||
|
||||
if batch.status in ["completed", "failed", "expired"]:
|
||||
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
|
||||
|
||||
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
|
||||
|
||||
if batch_id in self._processing_tasks:
|
||||
self._processing_tasks[batch_id].cancel()
|
||||
# note: task removal and status="cancelled" handled in finally block of _process_batch
|
||||
|
||||
return await self.retrieve_batch(batch_id)
|
||||
|
||||
async def list_batches(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ListBatchesResponse:
|
||||
"""
|
||||
List all batches, eventually only for the current user.
|
||||
|
||||
With no notion of user, we return all batches.
|
||||
"""
|
||||
batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff")
|
||||
|
||||
batches = []
|
||||
for batch_data in batch_values:
|
||||
if batch_data:
|
||||
batches.append(BatchObject.model_validate_json(batch_data))
|
||||
|
||||
batches.sort(key=lambda b: b.created_at, reverse=True)
|
||||
|
||||
start_idx = 0
|
||||
if after:
|
||||
for i, batch in enumerate(batches):
|
||||
if batch.id == after:
|
||||
start_idx = i + 1
|
||||
break
|
||||
|
||||
page_batches = batches[start_idx : start_idx + limit]
|
||||
has_more = (start_idx + limit) < len(batches)
|
||||
|
||||
first_id = page_batches[0].id if page_batches else None
|
||||
last_id = page_batches[-1].id if page_batches else None
|
||||
|
||||
return ListBatchesResponse(
|
||||
data=page_batches,
|
||||
first_id=first_id,
|
||||
last_id=last_id,
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Retrieve information about a specific batch."""
|
||||
batch_data = await self.kvstore.get(f"batch:{batch_id}")
|
||||
if not batch_data:
|
||||
raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
|
||||
|
||||
return BatchObject.model_validate_json(batch_data)
|
||||
|
||||
async def _update_batch(self, batch_id: str, **updates) -> None:
|
||||
"""Update batch fields in kvstore."""
|
||||
async with self._update_batch_lock:
|
||||
try:
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
|
||||
# batch processing is async. once cancelling, only allow "cancelled" status updates
|
||||
if batch.status == "cancelling" and updates.get("status") != "cancelled":
|
||||
logger.info(
|
||||
f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}"
|
||||
)
|
||||
return
|
||||
|
||||
if "errors" in updates:
|
||||
updates["errors"] = updates["errors"].model_dump()
|
||||
|
||||
batch_dict = batch.model_dump()
|
||||
batch_dict.update(updates)
|
||||
|
||||
await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update batch {batch_id}: {e}")
|
||||
|
||||
async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]:
|
||||
"""
|
||||
Read & validate input, return errors and valid input.
|
||||
|
||||
Validation of
|
||||
- input_file_id existance
|
||||
- valid json
|
||||
- custom_id, method, url, body presence and valid
|
||||
- no streaming
|
||||
"""
|
||||
requests: list[BatchRequest] = []
|
||||
errors: list[BatchError] = []
|
||||
try:
|
||||
await self.files_api.openai_retrieve_file(batch.input_file_id)
|
||||
except Exception:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=None,
|
||||
message=f"Cannot find file {batch.input_file_id}.",
|
||||
param="input_file_id",
|
||||
)
|
||||
)
|
||||
return errors, requests
|
||||
|
||||
# TODO(SECURITY): do something about large files
|
||||
file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
|
||||
file_content = file_content_response.body.decode("utf-8")
|
||||
for line_num, line in enumerate(file_content.strip().split("\n"), 1):
|
||||
if line.strip(): # skip empty lines
|
||||
try:
|
||||
request = json.loads(line)
|
||||
|
||||
if not isinstance(request, dict):
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message="Each line must be a JSON dictionary object",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
valid = True
|
||||
|
||||
for param, expected_type, type_string in [
|
||||
("custom_id", str, "string"),
|
||||
("method", str, "string"),
|
||||
("url", str, "string"),
|
||||
("body", dict, "JSON dictionary object"),
|
||||
]:
|
||||
if param not in request:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="missing_required_parameter",
|
||||
line=line_num,
|
||||
message=f"Missing required parameter: {param}",
|
||||
param=param,
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
elif not isinstance(request[param], expected_type):
|
||||
param_name = "URL" if param == "url" else param.capitalize()
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message=f"{param_name} must be a {type_string}",
|
||||
param=param,
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_url",
|
||||
line=line_num,
|
||||
message="URL provided for this request does not match the batch endpoint",
|
||||
param="url",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if (body := request.get("body")) and isinstance(body, dict):
|
||||
if body.get("stream", False):
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="streaming_unsupported",
|
||||
line=line_num,
|
||||
message="Streaming is not supported in batch processing",
|
||||
param="body.stream",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
for param, expected_type, type_string in [
|
||||
("model", str, "a string"),
|
||||
# messages is specific to /v1/chat/completions
|
||||
# we could skip validating messages here and let inference fail. however,
|
||||
# that would be a very expensive way to find out messages is wrong.
|
||||
("messages", list, "an array"), # TODO: allow messages to be a string?
|
||||
]:
|
||||
if param not in body:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message=f"{param.capitalize()} parameter is required",
|
||||
param=f"body.{param}",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
elif not isinstance(body[param], expected_type):
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message=f"{param.capitalize()} must be {type_string}",
|
||||
param=f"body.{param}",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if "model" in body and isinstance(body["model"], str):
|
||||
try:
|
||||
await self.models_api.get_model(body["model"])
|
||||
except Exception:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="model_not_found",
|
||||
line=line_num,
|
||||
message=f"Model '{body['model']}' does not exist or is not supported",
|
||||
param="body.model",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if valid:
|
||||
assert isinstance(url, str), "URL must be a string" # for mypy
|
||||
assert isinstance(body, dict), "Body must be a dictionary" # for mypy
|
||||
requests.append(
|
||||
BatchRequest(
|
||||
line_num=line_num,
|
||||
url=url,
|
||||
method=request["method"],
|
||||
custom_id=request["custom_id"],
|
||||
body=body,
|
||||
),
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_json_line",
|
||||
line=line_num,
|
||||
message="This line is not parseable as valid JSON.",
|
||||
)
|
||||
)
|
||||
|
||||
return errors, requests
|
||||
|
||||
async def _process_batch(self, batch_id: str) -> None:
|
||||
"""Background task to process a batch of requests."""
|
||||
try:
|
||||
logger.info(f"Starting batch processing for {batch_id}")
|
||||
async with self._batch_semaphore: # semaphore to limit concurrency
|
||||
logger.info(f"Acquired semaphore for batch {batch_id}")
|
||||
await self._process_batch_impl(batch_id)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Batch processing cancelled for {batch_id}")
|
||||
await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time()))
|
||||
except Exception as e:
|
||||
logger.error(f"Batch processing failed for {batch_id}: {e}")
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
status="failed",
|
||||
failed_at=int(time.time()),
|
||||
errors=Errors(data=[BatchError(code="internal_error", message=str(e))]),
|
||||
)
|
||||
finally:
|
||||
self._processing_tasks.pop(batch_id, None)
|
||||
|
||||
async def _process_batch_impl(self, batch_id: str) -> None:
|
||||
"""Implementation of batch processing logic."""
|
||||
errors: list[BatchError] = []
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
|
||||
errors, requests = await self._validate_input(batch)
|
||||
if errors:
|
||||
await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors))
|
||||
logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors")
|
||||
return
|
||||
|
||||
logger.info(f"Processing {len(requests)} requests for batch {batch_id}")
|
||||
|
||||
total_requests = len(requests)
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
status="in_progress",
|
||||
request_counts={"total": total_requests, "completed": 0, "failed": 0},
|
||||
)
|
||||
|
||||
error_results = []
|
||||
success_results = []
|
||||
completed_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch):
|
||||
# we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk]
|
||||
|
||||
chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True)
|
||||
|
||||
for result in chunk_results:
|
||||
if isinstance(result, dict) and result.get("error") is not None: # error response from inference
|
||||
failed_count += 1
|
||||
error_results.append(result)
|
||||
elif isinstance(result, dict) and result.get("response") is not None: # successful inference
|
||||
completed_count += 1
|
||||
success_results.append(result)
|
||||
else: # unexpected result
|
||||
failed_count += 1
|
||||
errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}"))
|
||||
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count},
|
||||
)
|
||||
|
||||
if errors:
|
||||
await self._update_batch(
|
||||
batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
output_file_id = await self._create_output_file(batch_id, success_results, "success")
|
||||
await self._update_batch(batch_id, output_file_id=output_file_id)
|
||||
|
||||
error_file_id = await self._create_output_file(batch_id, error_results, "error")
|
||||
await self._update_batch(batch_id, error_file_id=error_file_id)
|
||||
|
||||
await self._update_batch(batch_id, status="completed", completed_at=int(time.time()))
|
||||
|
||||
logger.info(
|
||||
f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed"
|
||||
)
|
||||
except Exception as e:
|
||||
# note: errors is empty at this point, so we don't lose anything by ignoring it
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
status="failed",
|
||||
failed_at=int(time.time()),
|
||||
errors=Errors(data=[BatchError(code="output_failed", message=str(e))]),
|
||||
)
|
||||
|
||||
async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict:
|
||||
"""Process a single request from the batch."""
|
||||
request_id = f"batch_req_{batch_id}_{request.line_num}"
|
||||
|
||||
try:
|
||||
# TODO(SECURITY): review body for security issues
|
||||
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
|
||||
chat_response = await self.inference_api.openai_chat_completion(**request.body)
|
||||
|
||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": request_id, # TODO: should this be different?
|
||||
"body": chat_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"error": {"type": "request_failed", "message": str(e)},
|
||||
}
|
||||
|
||||
async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str:
|
||||
"""
|
||||
Create an output file with batch results.
|
||||
|
||||
This function filters results based on the specified file_type
|
||||
and uploads the file to the Files API.
|
||||
"""
|
||||
output_lines = [json.dumps(result) for result in results]
|
||||
|
||||
with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer:
|
||||
file_buffer.filename = f"{batch_id}_{file_type}.jsonl"
|
||||
uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
|
||||
return uploaded_file.id
|
40
llama_stack/providers/inline/batches/reference/config.py
Normal file
40
llama_stack/providers/inline/batches/reference/config.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
|
||||
|
||||
class ReferenceBatchesImplConfig(BaseModel):
|
||||
"""Configuration for the Reference Batches implementation."""
|
||||
|
||||
kvstore: KVStoreConfig = Field(
|
||||
description="Configuration for the key-value store backend.",
|
||||
)
|
||||
|
||||
max_concurrent_batches: int = Field(
|
||||
default=1,
|
||||
description="Maximum number of concurrent batches to process simultaneously.",
|
||||
ge=1,
|
||||
)
|
||||
|
||||
max_concurrent_requests_per_batch: int = Field(
|
||||
default=10,
|
||||
description="Maximum number of concurrent requests to process per batch.",
|
||||
ge=1,
|
||||
)
|
||||
|
||||
# TODO: add a max requests per second rate limiter
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="batches.db",
|
||||
),
|
||||
}
|
|
@ -22,7 +22,7 @@ from llama_stack.apis.safety import (
|
|||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.models.llama.datatypes import Role
|
||||
|
@ -72,30 +72,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
|
|||
}
|
||||
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
|
||||
|
||||
OPENAI_TO_LLAMA_CATEGORIES_MAP = {
|
||||
OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES],
|
||||
OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES],
|
||||
OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION],
|
||||
OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
|
||||
OpenAICategories.HATE: [CAT_HATE],
|
||||
OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES],
|
||||
OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES],
|
||||
OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
|
||||
OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
|
||||
OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION],
|
||||
OpenAICategories.SELF_HARM: [CAT_SELF_HARM],
|
||||
OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM],
|
||||
OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
|
||||
# These are custom categories that are not in the OpenAI moderation categories
|
||||
"custom/defamation": [CAT_DEFAMATION],
|
||||
"custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],
|
||||
"custom/privacy_violation": [CAT_PRIVACY],
|
||||
"custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY],
|
||||
"custom/weapons": [CAT_INDISCRIMINATE_WEAPONS],
|
||||
"custom/elections": [CAT_ELECTIONS],
|
||||
"custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE],
|
||||
}
|
||||
|
||||
|
||||
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
|
||||
CAT_VIOLENT_CRIMES,
|
||||
|
@ -424,9 +400,9 @@ class LlamaGuardShield:
|
|||
ModerationObject with appropriate configuration
|
||||
"""
|
||||
# Set default values for safe case
|
||||
categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
|
||||
category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0)
|
||||
category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
|
||||
categories = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), False)
|
||||
category_scores = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), 1.0)
|
||||
category_applied_input_types = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
|
||||
flagged = False
|
||||
user_message = None
|
||||
metadata = {}
|
||||
|
@ -453,19 +429,15 @@ class LlamaGuardShield:
|
|||
],
|
||||
)
|
||||
|
||||
# Get OpenAI categories for the unsafe codes
|
||||
openai_categories = []
|
||||
for code in unsafe_code_list:
|
||||
llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
|
||||
openai_categories.extend(
|
||||
k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
|
||||
)
|
||||
llama_guard_category = [SAFETY_CODE_TO_CATEGORIES_MAP[code] for code in unsafe_code_list]
|
||||
|
||||
# Update categories for unsafe content
|
||||
categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
|
||||
category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
|
||||
categories = {k: k in llama_guard_category for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
|
||||
category_scores = {
|
||||
k: 1.0 if k in llama_guard_category else 0.0 for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
|
||||
}
|
||||
category_applied_input_types = {
|
||||
k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP
|
||||
k: ["text"] if k in llama_guard_category else [] for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
|
||||
}
|
||||
flagged = True
|
||||
user_message = CANNED_RESPONSE_TEXT
|
||||
|
|
|
@ -18,6 +18,7 @@ from llama_stack.apis.safety import (
|
|||
ShieldStore,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
@ -64,6 +65,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
|
||||
return await self.shield.run(messages)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
|
||||
|
||||
|
||||
class PromptGuardShield:
|
||||
def __init__(
|
||||
|
|
|
@ -33,6 +33,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
|
|||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@ -128,11 +129,12 @@ class FaissIndex(EmbeddingIndex):
|
|||
# Save updated index
|
||||
await self._save_index()
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
if chunk_id not in self.chunk_ids:
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
|
||||
if not set(chunk_ids).issubset(self.chunk_ids):
|
||||
return
|
||||
|
||||
async with self.chunk_id_lock:
|
||||
def remove_chunk(chunk_id: str):
|
||||
index = self.chunk_ids.index(chunk_id)
|
||||
self.index.remove_ids(np.array([index]))
|
||||
|
||||
|
@ -146,6 +148,10 @@ class FaissIndex(EmbeddingIndex):
|
|||
self.chunk_by_index = new_chunk_by_index
|
||||
self.chunk_ids.pop(index)
|
||||
|
||||
async with self.chunk_id_lock:
|
||||
for chunk_id in chunk_ids:
|
||||
remove_chunk(chunk_id)
|
||||
|
||||
await self._save_index()
|
||||
|
||||
async def query_vector(
|
||||
|
@ -297,8 +303,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
"""Delete a chunk from a faiss index"""
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a faiss index"""
|
||||
faiss_index = self.cache[store_id].index
|
||||
for chunk_id in chunk_ids:
|
||||
await faiss_index.delete_chunk(chunk_id)
|
||||
await faiss_index.delete_chunks(chunks_for_deletion)
|
||||
|
|
|
@ -31,6 +31,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIV
|
|||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_RRF,
|
||||
RERANKER_TYPE_WEIGHTED,
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@ -426,34 +427,36 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Remove a chunk from the SQLite vector store."""
|
||||
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
|
||||
|
||||
def _delete_chunk():
|
||||
def _delete_chunks():
|
||||
connection = _create_sqlite_connection(self.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
|
||||
# Delete from metadata table
|
||||
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,))
|
||||
placeholders = ",".join("?" * len(chunk_ids))
|
||||
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id IN ({placeholders})", chunk_ids)
|
||||
|
||||
# Delete from vector table
|
||||
cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,))
|
||||
cur.execute(f"DELETE FROM {self.vector_table} WHERE id IN ({placeholders})", chunk_ids)
|
||||
|
||||
# Delete from FTS table
|
||||
cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,))
|
||||
cur.execute(f"DELETE FROM {self.fts_table} WHERE id IN ({placeholders})", chunk_ids)
|
||||
|
||||
connection.commit()
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
logger.error(f"Error deleting chunk {chunk_id}: {e}")
|
||||
logger.error(f"Error deleting chunks: {e}")
|
||||
raise
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_delete_chunk)
|
||||
await asyncio.to_thread(_delete_chunks)
|
||||
|
||||
|
||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
|
@ -551,12 +554,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
"""Delete a chunk from a sqlite_vec index."""
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a sqlite_vec index."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(store_id)
|
||||
|
||||
for chunk_id in chunk_ids:
|
||||
# Use the index's delete_chunk method
|
||||
await index.index.delete_chunk(chunk_id)
|
||||
await index.index.delete_chunks(chunks_for_deletion)
|
||||
|
|
26
llama_stack/providers/registry/batches.py
Normal file
26
llama_stack/providers/registry/batches.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.batches,
|
||||
provider_type="inline::reference",
|
||||
pip_packages=["openai"],
|
||||
module="llama_stack.providers.inline.batches.reference",
|
||||
config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
Api.files,
|
||||
Api.models,
|
||||
],
|
||||
description="Reference implementation of batches API with KVStore persistence.",
|
||||
),
|
||||
]
|
|
@ -342,6 +342,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
|||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
|
@ -350,6 +351,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
|||
module="llama_stack.providers.inline.vector_io.chroma",
|
||||
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[Chroma](https://www.trychroma.com/) is an inline and remote vector
|
||||
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
|
||||
|
@ -464,6 +466,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
|
|||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
|
@ -731,6 +734,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
|
|||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
|
|
|
@ -235,6 +235,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
# TODO: tools are never added to the request, so we need to add them here
|
||||
if media_present or not llama_model:
|
||||
input_dict["messages"] = [
|
||||
await convert_message_to_openai_dict(m, download=True) for m in request.messages
|
||||
|
@ -378,6 +379,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
# Fireworks chat completions OpenAI-compatible API does not support
|
||||
# tool calls properly.
|
||||
llama_model = self.get_llama_model(model_obj.provider_resource_id)
|
||||
|
||||
if llama_model:
|
||||
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
|
||||
self,
|
||||
|
@ -431,4 +433,5 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
user=user,
|
||||
)
|
||||
|
||||
logger.debug(f"fireworks params: {params}")
|
||||
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
|
||||
|
|
|
@ -457,9 +457,6 @@ class OllamaInferenceAdapter(
|
|||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
model_obj = await self._get_model(model)
|
||||
if model_obj.model_type != ModelType.embedding:
|
||||
raise ValueError(f"Model {model} is not an embedding model")
|
||||
|
||||
if model_obj.provider_resource_id is None:
|
||||
raise ValueError(f"Model {model} has no provider_resource_id set")
|
||||
|
||||
|
|
|
@ -308,9 +308,7 @@ class TGIAdapter(_HfAdapter):
|
|||
if not config.url:
|
||||
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
||||
log.info(f"Initializing TGI client with url={config.url}")
|
||||
self.client = AsyncInferenceClient(
|
||||
model=config.url,
|
||||
)
|
||||
self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
|
|
|
@ -26,6 +26,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
|
|||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@ -146,8 +147,10 @@ class ChromaIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
await maybe_await(self.collection.delete([chunk_id]))
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete a single chunk from the Chroma collection by its ID."""
|
||||
ids = [f"{chunk.document_id}:{chunk.chunk_id}" for chunk in chunks_for_deletion]
|
||||
await maybe_await(self.collection.delete(ids=ids))
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
|
@ -175,6 +178,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.cache = {}
|
||||
self.kvstore: KVStore | None = None
|
||||
self.vector_db_store = None
|
||||
self.files_api = files_api
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
|
@ -258,5 +262,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.cache[vector_db_id] = index
|
||||
return index
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a Chroma vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {store_id} not found")
|
||||
|
||||
await index.index.delete_chunks(chunks_for_deletion)
|
||||
|
|
|
@ -28,6 +28,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
|
|||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_WEIGHTED,
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@ -287,14 +288,17 @@ class MilvusIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Remove a chunk from the Milvus collection."""
|
||||
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
|
||||
try:
|
||||
# Use IN clause with square brackets and single quotes for VARCHAR field
|
||||
chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
|
||||
await asyncio.to_thread(
|
||||
self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
|
||||
self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
|
||||
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
|
@ -420,12 +424,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete a chunk from a milvus vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(store_id)
|
||||
|
||||
for chunk_id in chunk_ids:
|
||||
# Use the index's delete_chunk method
|
||||
await index.index.delete_chunk(chunk_id)
|
||||
await index.index.delete_chunks(chunks_for_deletion)
|
||||
|
|
|
@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
|
|||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@ -163,10 +164,11 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Remove a chunk from the PostgreSQL table."""
|
||||
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,))
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
|
||||
|
||||
|
||||
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
|
@ -275,12 +277,10 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete a chunk from a PostgreSQL vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(store_id)
|
||||
|
||||
for chunk_id in chunk_ids:
|
||||
# Use the index's delete_chunk method
|
||||
await index.index.delete_chunk(chunk_id)
|
||||
await index.index.delete_chunks(chunks_for_deletion)
|
||||
|
|
|
@ -29,6 +29,7 @@ from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig a
|
|||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@ -88,15 +89,16 @@ class QdrantIndex(EmbeddingIndex):
|
|||
|
||||
await self.client.upsert(collection_name=self.collection_name, points=points)
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Remove a chunk from the Qdrant collection."""
|
||||
chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
|
||||
try:
|
||||
await self.client.delete(
|
||||
collection_name=self.collection_name,
|
||||
points_selector=models.PointIdsList(points=[convert_id(chunk_id)]),
|
||||
points_selector=models.PointIdsList(points=chunk_ids),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}")
|
||||
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
|
||||
raise
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
|
@ -264,12 +266,14 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
) -> VectorStoreFileObject:
|
||||
# Qdrant doesn't allow multiple clients to access the same storage path simultaneously.
|
||||
async with self._qdrant_lock:
|
||||
await super().openai_attach_file_to_vector_store(vector_store_id, file_id, attributes, chunking_strategy)
|
||||
return await super().openai_attach_file_to_vector_store(
|
||||
vector_store_id, file_id, attributes, chunking_strategy
|
||||
)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a Qdrant vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {store_id} not found")
|
||||
for chunk_id in chunk_ids:
|
||||
await index.index.delete_chunk(chunk_id)
|
||||
|
||||
await index.index.delete_chunks(chunks_for_deletion)
|
||||
|
|
|
@ -26,6 +26,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
|
|||
OpenAIVectorStoreMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@ -67,6 +68,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
data_objects.append(
|
||||
wvc.data.DataObject(
|
||||
properties={
|
||||
"chunk_id": chunk.chunk_id,
|
||||
"chunk_content": chunk.model_dump_json(),
|
||||
},
|
||||
vector=embeddings[i].tolist(),
|
||||
|
@ -79,10 +81,11 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
# TODO: make this async friendly
|
||||
collection.data.insert_many(data_objects)
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
|
||||
collection = self.client.collections.get(sanitized_collection_name)
|
||||
collection.data.delete_many(where=Filter.by_property("id").contains_any([chunk_id]))
|
||||
chunk_ids = [chunk.chunk_id for chunk in chunks_for_deletion]
|
||||
collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
|
||||
|
@ -307,10 +310,10 @@ class WeaviateVectorIOAdapter(
|
|||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True)
|
||||
index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {sanitized_collection_name} not found")
|
||||
|
||||
await index.delete(chunk_ids)
|
||||
await index.index.delete_chunks(chunks_for_deletion)
|
||||
|
|
|
@ -31,15 +31,15 @@ from openai.types.chat import (
|
|||
from openai.types.chat import (
|
||||
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageToolCall,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
|
||||
)
|
||||
|
@ -70,7 +70,7 @@ from openai.types.chat.chat_completion_chunk import (
|
|||
from openai.types.chat.chat_completion_content_part_image_param import (
|
||||
ImageURL as OpenAIImageURL,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
Function as OpenAIFunction,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
@ -633,7 +633,7 @@ async def convert_message_to_openai_dict_new(
|
|||
)
|
||||
elif isinstance(message, CompletionMessage):
|
||||
tool_calls = [
|
||||
OpenAIChatCompletionMessageToolCall(
|
||||
OpenAIChatCompletionMessageFunctionToolCall(
|
||||
id=tool.call_id,
|
||||
function=OpenAIFunction(
|
||||
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
|
||||
|
@ -903,7 +903,7 @@ def _convert_openai_request_response_format(
|
|||
|
||||
|
||||
def _convert_openai_tool_calls(
|
||||
tool_calls: list[OpenAIChatCompletionMessageToolCall],
|
||||
tool_calls: list[OpenAIChatCompletionMessageFunctionToolCall],
|
||||
) -> list[ToolCall]:
|
||||
"""
|
||||
Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import time
|
||||
import uuid
|
||||
|
@ -37,10 +36,15 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
content_from_data_and_mime_type,
|
||||
make_overlapped_chunks,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__, category="vector_io")
|
||||
|
||||
# Constants for OpenAI vector stores
|
||||
CHUNK_MULTIPLIER = 5
|
||||
|
@ -154,8 +158,8 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
@abstractmethod
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
"""Delete a chunk from a vector store."""
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a vector store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
@ -614,7 +618,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
)
|
||||
vector_store_file_object.status = "completed"
|
||||
except Exception as e:
|
||||
logger.error(f"Error attaching file to vector store: {e}")
|
||||
logger.exception("Error attaching file to vector store")
|
||||
vector_store_file_object.status = "failed"
|
||||
vector_store_file_object.last_error = VectorStoreFileLastError(
|
||||
code="server_error",
|
||||
|
@ -767,7 +771,21 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
|
||||
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
|
||||
chunks = [Chunk.model_validate(c) for c in dict_chunks]
|
||||
await self.delete_chunks(vector_store_id, [str(c.chunk_id) for c in chunks if c.chunk_id])
|
||||
|
||||
# Create ChunkForDeletion objects with both chunk_id and document_id
|
||||
chunks_for_deletion = []
|
||||
for c in chunks:
|
||||
if c.chunk_id:
|
||||
document_id = c.metadata.get("document_id") or (
|
||||
c.chunk_metadata.document_id if c.chunk_metadata else None
|
||||
)
|
||||
if document_id:
|
||||
chunks_for_deletion.append(ChunkForDeletion(chunk_id=str(c.chunk_id), document_id=document_id))
|
||||
else:
|
||||
logger.warning(f"Chunk {c.chunk_id} has no document_id, skipping deletion")
|
||||
|
||||
if chunks_for_deletion:
|
||||
await self.delete_chunks(vector_store_id, chunks_for_deletion)
|
||||
|
||||
store_info = self.openai_vector_stores[vector_store_id].copy()
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ from urllib.parse import unquote
|
|||
import httpx
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
|
@ -34,6 +35,18 @@ from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
|||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChunkForDeletion(BaseModel):
|
||||
"""Information needed to delete a chunk from a vector store.
|
||||
|
||||
:param chunk_id: The ID of the chunk to delete
|
||||
:param document_id: The ID of the document this chunk belongs to
|
||||
"""
|
||||
|
||||
chunk_id: str
|
||||
document_id: str
|
||||
|
||||
|
||||
# Constants for reranker types
|
||||
RERANKER_TYPE_RRF = "rrf"
|
||||
RERANKER_TYPE_WEIGHTED = "weighted"
|
||||
|
@ -232,7 +245,7 @@ class EmbeddingIndex(ABC):
|
|||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def delete_chunk(self, chunk_id: str):
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
|
|
1
llama_stack/ui/.nvmrc
Normal file
1
llama_stack/ui/.nvmrc
Normal file
|
@ -0,0 +1 @@
|
|||
22.5.1
|
|
@ -1,3 +1,12 @@
|
|||
# Ignore artifacts:
|
||||
build
|
||||
coverage
|
||||
.next
|
||||
node_modules
|
||||
dist
|
||||
*.lock
|
||||
*.log
|
||||
|
||||
# Generated files
|
||||
*.min.js
|
||||
*.min.css
|
||||
|
|
|
@ -1 +1,10 @@
|
|||
{}
|
||||
{
|
||||
"semi": true,
|
||||
"trailingComma": "es5",
|
||||
"singleQuote": false,
|
||||
"printWidth": 80,
|
||||
"tabWidth": 2,
|
||||
"useTabs": false,
|
||||
"bracketSpacing": true,
|
||||
"arrowParens": "avoid"
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ async function proxyRequest(request: NextRequest, method: string) {
|
|||
const responseText = await response.text();
|
||||
|
||||
console.log(
|
||||
`Response from FastAPI: ${response.status} ${response.statusText}`,
|
||||
`Response from FastAPI: ${response.status} ${response.statusText}`
|
||||
);
|
||||
|
||||
// Create response with same status and headers
|
||||
|
@ -74,7 +74,7 @@ async function proxyRequest(request: NextRequest, method: string) {
|
|||
backend_url: BACKEND_URL,
|
||||
timestamp: new Date().toISOString(),
|
||||
},
|
||||
{ status: 500 },
|
||||
{ status: 500 }
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -51,9 +51,9 @@ export default function SignInPage() {
|
|||
onClick={() => {
|
||||
console.log("Signing in with GitHub...");
|
||||
signIn("github", { callbackUrl: "/auth/signin" }).catch(
|
||||
(error) => {
|
||||
error => {
|
||||
console.error("Sign in error:", error);
|
||||
},
|
||||
}
|
||||
);
|
||||
}}
|
||||
className="w-full"
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue