diff --git a/.github/actions/run-and-record-tests/action.yml b/.github/actions/run-and-record-tests/action.yml index 573148e46..1406c6077 100644 --- a/.github/actions/run-and-record-tests/action.yml +++ b/.github/actions/run-and-record-tests/action.yml @@ -2,9 +2,13 @@ name: 'Run and Record Tests' description: 'Run integration tests and handle recording/artifact upload' inputs: - test-types: - description: 'JSON array of test types to run' + test-subdirs: + description: 'Comma-separated list of test subdirectories to run' required: true + test-pattern: + description: 'Regex pattern to pass to pytest -k' + required: false + default: '' stack-config: description: 'Stack configuration to use' required: true @@ -35,9 +39,11 @@ runs: ./scripts/integration-tests.sh \ --stack-config '${{ inputs.stack-config }}' \ --provider '${{ inputs.provider }}' \ - --test-types '${{ inputs.test-types }}' \ + --test-subdirs '${{ inputs.test-subdirs }}' \ + --test-pattern '${{ inputs.test-pattern }}' \ --inference-mode '${{ inputs.inference-mode }}' \ - ${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} + ${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} \ + | tee pytest-${{ inputs.inference-mode }}.log - name: Commit and push recordings @@ -57,10 +63,10 @@ runs: git commit -m "Recordings update from CI" fi - git fetch origin ${{ github.event.pull_request.head.ref }} - git rebase origin/${{ github.event.pull_request.head.ref }} + git fetch origin ${{ github.ref_name }} + git rebase origin/${{ github.ref_name }} echo "Rebased successfully" - git push origin HEAD:${{ github.event.pull_request.head.ref }} + git push origin HEAD:${{ github.ref_name }} echo "Pushed successfully" else echo "No recording changes" diff --git a/.github/actions/setup-runner/action.yml b/.github/actions/setup-runner/action.yml index 0be999fe2..1ca02bbff 100644 --- a/.github/actions/setup-runner/action.yml +++ b/.github/actions/setup-runner/action.yml @@ -28,7 +28,7 @@ runs: # Install llama-stack-client-python based on the client-version input if [ "${{ inputs.client-version }}" = "latest" ]; then echo "Installing latest llama-stack-client-python from main branch" - uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main + uv pip install git+https://github.com/llamastack/llama-stack-client-python.git@main elif [ "${{ inputs.client-version }}" = "published" ]; then echo "Installing published llama-stack-client-python from PyPI" uv pip install llama-stack-client diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index a38d4971a..fc56f62ea 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -31,6 +31,14 @@ on: description: 'Test against a specific provider' type: string default: 'ollama' + test-subdirs: + description: 'Comma-separated list of test subdirectories to run' + type: string + default: '' + test-pattern: + description: 'Regex pattern to pass to pytest -k' + type: string + default: '' concurrency: # Skip concurrency for pushes to main - each commit should be tested independently @@ -38,27 +46,8 @@ concurrency: cancel-in-progress: true jobs: - discover-tests: - runs-on: ubuntu-latest - outputs: - test-types: ${{ steps.generate-test-types.outputs.test-types }} - - steps: - - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Generate test types - id: generate-test-types - run: | - # Get test directories dynamically, excluding non-test directories - # NOTE: we are excluding post_training since the tests take too long - TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" | - grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" | - sort | jq -R -s -c 'split("\n")[:-1]') - echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT run-replay-mode-tests: - needs: discover-tests runs-on: ubuntu-latest name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }} @@ -89,7 +78,8 @@ jobs: - name: Run tests uses: ./.github/actions/run-and-record-tests with: - test-types: ${{ needs.discover-tests.outputs.test-types }} + test-subdirs: ${{ inputs.test-subdirs }} + test-pattern: ${{ inputs.test-pattern }} stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }} provider: ${{ matrix.provider }} inference-mode: 'replay' diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index aa239572b..99a44c147 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -14,9 +14,11 @@ on: - 'pyproject.toml' - 'requirements.txt' - '.github/workflows/integration-vector-io-tests.yml' # This workflow + schedule: + - cron: '0 0 * * *' # (test on python 3.13) Daily at 12 AM UTC concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -25,7 +27,7 @@ jobs: strategy: matrix: vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"] - python-version: ["3.12", "3.13"] + python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} fail-fast: false # we want to run all tests regardless of failure steps: @@ -164,9 +166,9 @@ jobs: ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }} WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }} run: | - uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ + uv run pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ tests/integration/vector_io \ - --embedding-model sentence-transformers/all-MiniLM-L6-v2 + --embedding-model inline::sentence-transformers/all-MiniLM-L6-v2 - name: Check Storage and Memory Available After Tests if: ${{ always() }} diff --git a/.github/workflows/record-integration-tests.yml b/.github/workflows/record-integration-tests.yml index 12957db27..22636f209 100644 --- a/.github/workflows/record-integration-tests.yml +++ b/.github/workflows/record-integration-tests.yml @@ -1,93 +1,53 @@ +# This workflow should be run manually when needing to re-record tests. This happens when you have +# - added a new test +# - or changed an existing test such that a new inference call is made +# You should make a PR and then run this workflow on that PR branch. The workflow will re-record the +# tests and commit the recordings to the PR branch. name: Integration Tests (Record) run-name: Run the integration test suite from tests/integration on: - pull_request: - branches: [ main ] - types: [opened, synchronize, labeled] - paths: - - 'llama_stack/**' - - 'tests/**' - - 'uv.lock' - - 'pyproject.toml' - - '.github/workflows/record-integration-tests.yml' # This workflow - - '.github/actions/setup-ollama/action.yml' - - '.github/actions/setup-test-environment/action.yml' - - '.github/actions/run-and-record-tests/action.yml' workflow_dispatch: inputs: + test-subdirs: + description: 'Comma-separated list of test subdirectories to run' + type: string + default: '' test-provider: description: 'Test against a specific provider' type: string default: 'ollama' - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true + run-vision-tests: + description: 'Whether to run vision tests' + type: boolean + default: false + test-pattern: + description: 'Regex pattern to pass to pytest -k' + type: string + default: '' jobs: - discover-tests: - if: contains(github.event.pull_request.labels.*.name, 're-record-tests') || - contains(github.event.pull_request.labels.*.name, 're-record-vision-tests') - runs-on: ubuntu-latest - outputs: - test-types: ${{ steps.generate-test-types.outputs.test-types }} - matrix-modes: ${{ steps.generate-test-types.outputs.matrix-modes }} - - steps: - - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Generate test types - id: generate-test-types - run: | - # Get test directories dynamically, excluding non-test directories - TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" | - grep -Ev "^(__pycache__|fixtures|test_cases|recordings|post_training)$" | - sort | jq -R -s -c 'split("\n")[:-1]') - echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT - - labels=$(gh pr view ${{ github.event.pull_request.number }} --json labels --jq '.labels[].name') - echo "labels=$labels" - - modes_array=() - if [[ $labels == *"re-record-vision-tests"* ]]; then - modes_array+=("vision") - fi - if [[ $labels == *"re-record-tests"* ]]; then - modes_array+=("non-vision") - fi - - # Convert to JSON array - if [ ${#modes_array[@]} -eq 0 ]; then - matrix_modes="[]" - else - matrix_modes=$(printf '%s\n' "${modes_array[@]}" | jq -R -s -c 'split("\n")[:-1]') - fi - echo "matrix_modes=$matrix_modes" - echo "matrix-modes=$matrix_modes" >> $GITHUB_OUTPUT - - env: - GH_TOKEN: ${{ github.token }} - record-tests: - needs: discover-tests runs-on: ubuntu-latest permissions: contents: write - strategy: - fail-fast: false - matrix: - mode: ${{ fromJSON(needs.discover-tests.outputs.matrix-modes) }} - steps: + - name: Echo workflow inputs + run: | + echo "::group::Workflow Inputs" + echo "test-subdirs: ${{ inputs.test-subdirs }}" + echo "test-provider: ${{ inputs.test-provider }}" + echo "run-vision-tests: ${{ inputs.run-vision-tests }}" + echo "test-pattern: ${{ inputs.test-pattern }}" + echo "branch: ${{ github.ref_name }}" + echo "::endgroup::" + - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: - ref: ${{ github.event.pull_request.head.ref }} fetch-depth: 0 - name: Setup test environment @@ -96,14 +56,15 @@ jobs: python-version: "3.12" # Use single Python version for recording client-version: "latest" provider: ${{ inputs.test-provider || 'ollama' }} - run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }} + run-vision-tests: ${{ inputs.run-vision-tests }} inference-mode: 'record' - name: Run and record tests uses: ./.github/actions/run-and-record-tests with: - test-types: ${{ needs.discover-tests.outputs.test-types }} + test-pattern: ${{ inputs.test-pattern }} + test-subdirs: ${{ inputs.test-subdirs }} stack-config: 'server:ci-tests' # recording must be done with server since more tests are run provider: ${{ inputs.test-provider || 'ollama' }} inference-mode: 'record' - run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }} + run-vision-tests: ${{ inputs.run-vision-tests }} diff --git a/.github/workflows/semantic-pr.yml b/.github/workflows/semantic-pr.yml index 4df7324c4..57a4df646 100644 --- a/.github/workflows/semantic-pr.yml +++ b/.github/workflows/semantic-pr.yml @@ -11,7 +11,7 @@ on: - synchronize concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number }} cancel-in-progress: true permissions: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 30843173c..4309f289a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,6 +2,7 @@ exclude: 'build/' default_language_version: python: python3.12 + node: "22" repos: - repo: https://github.com/pre-commit/pre-commit-hooks @@ -145,6 +146,20 @@ repos: pass_filenames: false require_serial: true files: ^.github/workflows/.*$ + - id: ui-prettier + name: Format UI code with Prettier + entry: bash -c 'cd llama_stack/ui && npm run format' + language: system + files: ^llama_stack/ui/.*\.(ts|tsx)$ + pass_filenames: false + require_serial: true + - id: ui-eslint + name: Lint UI code with ESLint + entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet' + language: system + files: ^llama_stack/ui/.*\.(ts|tsx)$ + pass_filenames: false + require_serial: true ci: autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 066fcecf0..c81e9e7b1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,13 +1,82 @@ -# Contributing to Llama-Stack +# Contributing to Llama Stack We want to make contributing to this project as easy and transparent as possible. +## Set up your development environment + +We use [uv](https://github.com/astral-sh/uv) to manage python dependencies and virtual environments. +You can install `uv` by following this [guide](https://docs.astral.sh/uv/getting-started/installation/). + +You can install the dependencies by running: + +```bash +cd llama-stack +uv sync --group dev +uv pip install -e . +source .venv/bin/activate +``` + +```{note} +You can use a specific version of Python with `uv` by adding the `--python ` flag (e.g. `--python 3.12`). +Otherwise, `uv` will automatically select a Python version according to the `requires-python` section of the `pyproject.toml`. +For more info, see the [uv docs around Python versions](https://docs.astral.sh/uv/concepts/python-versions/). +``` + +Note that you can create a dotenv file `.env` that includes necessary environment variables: +``` +LLAMA_STACK_BASE_URL=http://localhost:8321 +LLAMA_STACK_CLIENT_LOG=debug +LLAMA_STACK_PORT=8321 +LLAMA_STACK_CONFIG= +TAVILY_SEARCH_API_KEY= +BRAVE_SEARCH_API_KEY= +``` + +And then use this dotenv file when running client SDK tests via the following: +```bash +uv run --env-file .env -- pytest -v tests/integration/inference/test_text_inference.py --text-model=meta-llama/Llama-3.1-8B-Instruct +``` + +### Pre-commit Hooks + +We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running: + +```bash +uv run pre-commit install +``` + +After that, pre-commit hooks will run automatically before each commit. + +Alternatively, if you don't want to install the pre-commit hooks, you can run the checks manually by running: + +```bash +uv run pre-commit run --all-files +``` + +```{caution} +Before pushing your changes, make sure that the pre-commit hooks have passed successfully. +``` + ## Discussions -> Issues -> Pull Requests We actively welcome your pull requests. However, please read the following. This is heavily inspired by [Ghostty](https://github.com/ghostty-org/ghostty/blob/main/CONTRIBUTING.md). If in doubt, please open a [discussion](https://github.com/meta-llama/llama-stack/discussions); we can always convert that to an issue later. +### Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](http://facebook.com/whitehat/info) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +### Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + **I'd like to contribute!** If you are new to the project, start by looking at the issues tagged with "good first issue". If you're interested @@ -51,93 +120,15 @@ Please avoid picking up too many issues at once. This helps you stay focused and Please keep pull requests (PRs) small and focused. If you have a large set of changes, consider splitting them into logically grouped, smaller PRs to facilitate review and testing. -> [!TIP] -> As a general guideline: -> - Experienced contributors should try to keep no more than 5 open PRs at a time. -> - New contributors are encouraged to have only one open PR at a time until they’re familiar with the codebase and process. - -## Contributor License Agreement ("CLA") -In order to accept your pull request, we need you to submit a CLA. You only need -to do this once to work on any of Meta's open source projects. - -Complete your CLA here: - -## Issues -We use GitHub issues to track public bugs. Please ensure your description is -clear and has sufficient instructions to be able to reproduce the issue. - -Meta has a [bounty program](http://facebook.com/whitehat/info) for the safe -disclosure of security bugs. In those cases, please go through the process -outlined on that page and do not file a public issue. - - -## Set up your development environment - -We use [uv](https://github.com/astral-sh/uv) to manage python dependencies and virtual environments. -You can install `uv` by following this [guide](https://docs.astral.sh/uv/getting-started/installation/). - -You can install the dependencies by running: - -```bash -cd llama-stack -uv sync --group dev -uv pip install -e . -source .venv/bin/activate +```{tip} +As a general guideline: +- Experienced contributors should try to keep no more than 5 open PRs at a time. +- New contributors are encouraged to have only one open PR at a time until they’re familiar with the codebase and process. ``` -> [!NOTE] -> You can use a specific version of Python with `uv` by adding the `--python ` flag (e.g. `--python 3.12`) -> Otherwise, `uv` will automatically select a Python version according to the `requires-python` section of the `pyproject.toml`. -> For more info, see the [uv docs around Python versions](https://docs.astral.sh/uv/concepts/python-versions/). +## Repository guidelines -Note that you can create a dotenv file `.env` that includes necessary environment variables: -``` -LLAMA_STACK_BASE_URL=http://localhost:8321 -LLAMA_STACK_CLIENT_LOG=debug -LLAMA_STACK_PORT=8321 -LLAMA_STACK_CONFIG= -TAVILY_SEARCH_API_KEY= -BRAVE_SEARCH_API_KEY= -``` - -And then use this dotenv file when running client SDK tests via the following: -```bash -uv run --env-file .env -- pytest -v tests/integration/inference/test_text_inference.py --text-model=meta-llama/Llama-3.1-8B-Instruct -``` - -## Pre-commit Hooks - -We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running: - -```bash -uv run pre-commit install -``` - -After that, pre-commit hooks will run automatically before each commit. - -Alternatively, if you don't want to install the pre-commit hooks, you can run the checks manually by running: - -```bash -uv run pre-commit run --all-files -``` - -> [!CAUTION] -> Before pushing your changes, make sure that the pre-commit hooks have passed successfully. - -## Running tests - -You can find the Llama Stack testing documentation [here](https://github.com/meta-llama/llama-stack/blob/main/tests/README.md). - -## Adding a new dependency to the project - -To add a new dependency to the project, you can use the `uv` command. For example, to add `foo` to the project, you can run: - -```bash -uv add foo -uv sync -``` - -## Coding Style +### Coding Style * Comments should provide meaningful insights into the code. Avoid filler comments that simply describe the next step, as they create unnecessary clutter, same goes for docstrings. @@ -159,6 +150,10 @@ uv sync * When possible, use keyword arguments only when calling functions. * Llama Stack utilizes [custom Exception classes](llama_stack/apis/common/errors.py) for certain Resources that should be used where applicable. +### License +By contributing to Llama, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. + ## Common Tasks Some tips about common tasks you work on while contributing to Llama Stack: @@ -210,8 +205,4 @@ If you modify or add new API endpoints, update the API documentation accordingly uv run ./docs/openapi_generator/run_openapi_generator.sh ``` -The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing. - -## License -By contributing to Llama, you agree that your contributions will be licensed -under the LICENSE file in the root directory of this source tree. +The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing. \ No newline at end of file diff --git a/README.md b/README.md index 8db4580a2..4df4a5372 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,5 @@ # Llama Stack -meta-llama%2Fllama-stack | Trendshift - ------ [![PyPI version](https://img.shields.io/pypi/v/llama_stack.svg)](https://pypi.org/project/llama_stack/) [![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/) [![License](https://img.shields.io/pypi/l/llama_stack.svg)](https://github.com/meta-llama/llama-stack/blob/main/LICENSE) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index d480ff592..b36626719 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -8293,28 +8293,60 @@ "type": "array", "items": { "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" + "properties": { + "attributes": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } + "description": "(Optional) Key-value attributes associated with the file" + }, + "file_id": { + "type": "string", + "description": "Unique identifier of the file containing the result" + }, + "filename": { + "type": "string", + "description": "Name of the file containing the result" + }, + "score": { + "type": "number", + "description": "Relevance score for this search result (between 0 and 1)" + }, + "text": { + "type": "string", + "description": "Text content of the search result" + } + }, + "additionalProperties": false, + "required": [ + "attributes", + "file_id", + "filename", + "score", + "text" + ], + "title": "OpenAIResponseOutputMessageFileSearchToolCallResults", + "description": "Search results returned by the file search operation." }, "description": "(Optional) Search results returned by the file search operation" } @@ -8515,6 +8547,13 @@ "$ref": "#/components/schemas/OpenAIResponseInputTool" } }, + "include": { + "type": "array", + "items": { + "type": "string" + }, + "description": "(Optional) Additional fields to include in the response." + }, "max_infer_iters": { "type": "integer" } @@ -8782,6 +8821,61 @@ "title": "OpenAIResponseOutputMessageMCPListTools", "description": "MCP list tools output message containing available tools from an MCP server." }, + "OpenAIResponseContentPart": { + "oneOf": [ + { + "$ref": "#/components/schemas/OpenAIResponseContentPartOutputText" + }, + { + "$ref": "#/components/schemas/OpenAIResponseContentPartRefusal" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "output_text": "#/components/schemas/OpenAIResponseContentPartOutputText", + "refusal": "#/components/schemas/OpenAIResponseContentPartRefusal" + } + } + }, + "OpenAIResponseContentPartOutputText": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "output_text", + "default": "output_text" + }, + "text": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "text" + ], + "title": "OpenAIResponseContentPartOutputText" + }, + "OpenAIResponseContentPartRefusal": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "refusal", + "default": "refusal" + }, + "refusal": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "refusal" + ], + "title": "OpenAIResponseContentPartRefusal" + }, "OpenAIResponseObjectStream": { "oneOf": [ { @@ -8838,6 +8932,12 @@ { "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted" }, + { + "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded" + }, + { + "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone" + }, { "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" } @@ -8863,6 +8963,8 @@ "response.mcp_call.in_progress": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress", "response.mcp_call.failed": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed", "response.mcp_call.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted", + "response.content_part.added": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded", + "response.content_part.done": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone", "response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" } } @@ -8889,6 +8991,80 @@ "title": "OpenAIResponseObjectStreamResponseCompleted", "description": "Streaming event indicating a response has been completed." }, + "OpenAIResponseObjectStreamResponseContentPartAdded": { + "type": "object", + "properties": { + "response_id": { + "type": "string", + "description": "Unique identifier of the response containing this content" + }, + "item_id": { + "type": "string", + "description": "Unique identifier of the output item containing this content part" + }, + "part": { + "$ref": "#/components/schemas/OpenAIResponseContentPart", + "description": "The content part that was added" + }, + "sequence_number": { + "type": "integer", + "description": "Sequential number for ordering streaming events" + }, + "type": { + "type": "string", + "const": "response.content_part.added", + "default": "response.content_part.added", + "description": "Event type identifier, always \"response.content_part.added\"" + } + }, + "additionalProperties": false, + "required": [ + "response_id", + "item_id", + "part", + "sequence_number", + "type" + ], + "title": "OpenAIResponseObjectStreamResponseContentPartAdded", + "description": "Streaming event for when a new content part is added to a response item." + }, + "OpenAIResponseObjectStreamResponseContentPartDone": { + "type": "object", + "properties": { + "response_id": { + "type": "string", + "description": "Unique identifier of the response containing this content" + }, + "item_id": { + "type": "string", + "description": "Unique identifier of the output item containing this content part" + }, + "part": { + "$ref": "#/components/schemas/OpenAIResponseContentPart", + "description": "The completed content part" + }, + "sequence_number": { + "type": "integer", + "description": "Sequential number for ordering streaming events" + }, + "type": { + "type": "string", + "const": "response.content_part.done", + "default": "response.content_part.done", + "description": "Event type identifier, always \"response.content_part.done\"" + } + }, + "additionalProperties": false, + "required": [ + "response_id", + "item_id", + "part", + "sequence_number", + "type" + ], + "title": "OpenAIResponseObjectStreamResponseContentPartDone", + "description": "Streaming event for when a content part is completed." + }, "OpenAIResponseObjectStreamResponseCreated": { "type": "object", "properties": { @@ -14591,7 +14767,8 @@ "OpenAIFilePurpose": { "type": "string", "enum": [ - "assistants" + "assistants", + "batch" ], "title": "OpenAIFilePurpose", "description": "Valid purpose values for OpenAI Files API." @@ -14668,7 +14845,8 @@ "purpose": { "type": "string", "enum": [ - "assistants" + "assistants", + "batch" ], "description": "The intended purpose of the file" } @@ -16530,7 +16708,7 @@ "additionalProperties": { "type": "number" }, - "description": "A list of the categories along with their scores as predicted by model. Required set of categories that need to be in response - violence - violence/graphic - harassment - harassment/threatening - hate - hate/threatening - illicit - illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent - self-harm/instructions" + "description": "A list of the categories along with their scores as predicted by model." }, "user_message": { "type": "string" diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 9c0fba554..e7733b3c3 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6021,14 +6021,44 @@ components: type: array items: type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object + properties: + attributes: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + (Optional) Key-value attributes associated with the file + file_id: + type: string + description: >- + Unique identifier of the file containing the result + filename: + type: string + description: Name of the file containing the result + score: + type: number + description: >- + Relevance score for this search result (between 0 and 1) + text: + type: string + description: Text content of the search result + additionalProperties: false + required: + - attributes + - file_id + - filename + - score + - text + title: >- + OpenAIResponseOutputMessageFileSearchToolCallResults + description: >- + Search results returned by the file search operation. description: >- (Optional) Search results returned by the file search operation additionalProperties: false @@ -6188,6 +6218,12 @@ components: type: array items: $ref: '#/components/schemas/OpenAIResponseInputTool' + include: + type: array + items: + type: string + description: >- + (Optional) Additional fields to include in the response. max_infer_iters: type: integer additionalProperties: false @@ -6405,6 +6441,43 @@ components: title: OpenAIResponseOutputMessageMCPListTools description: >- MCP list tools output message containing available tools from an MCP server. + OpenAIResponseContentPart: + oneOf: + - $ref: '#/components/schemas/OpenAIResponseContentPartOutputText' + - $ref: '#/components/schemas/OpenAIResponseContentPartRefusal' + discriminator: + propertyName: type + mapping: + output_text: '#/components/schemas/OpenAIResponseContentPartOutputText' + refusal: '#/components/schemas/OpenAIResponseContentPartRefusal' + OpenAIResponseContentPartOutputText: + type: object + properties: + type: + type: string + const: output_text + default: output_text + text: + type: string + additionalProperties: false + required: + - type + - text + title: OpenAIResponseContentPartOutputText + OpenAIResponseContentPartRefusal: + type: object + properties: + type: + type: string + const: refusal + default: refusal + refusal: + type: string + additionalProperties: false + required: + - type + - refusal + title: OpenAIResponseContentPartRefusal OpenAIResponseObjectStream: oneOf: - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated' @@ -6425,6 +6498,8 @@ components: - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted' + - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded' + - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted' discriminator: propertyName: type @@ -6447,6 +6522,8 @@ components: response.mcp_call.in_progress: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress' response.mcp_call.failed: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed' response.mcp_call.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted' + response.content_part.added: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded' + response.content_part.done: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone' response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted' "OpenAIResponseObjectStreamResponseCompleted": type: object @@ -6468,6 +6545,76 @@ components: OpenAIResponseObjectStreamResponseCompleted description: >- Streaming event indicating a response has been completed. + "OpenAIResponseObjectStreamResponseContentPartAdded": + type: object + properties: + response_id: + type: string + description: >- + Unique identifier of the response containing this content + item_id: + type: string + description: >- + Unique identifier of the output item containing this content part + part: + $ref: '#/components/schemas/OpenAIResponseContentPart' + description: The content part that was added + sequence_number: + type: integer + description: >- + Sequential number for ordering streaming events + type: + type: string + const: response.content_part.added + default: response.content_part.added + description: >- + Event type identifier, always "response.content_part.added" + additionalProperties: false + required: + - response_id + - item_id + - part + - sequence_number + - type + title: >- + OpenAIResponseObjectStreamResponseContentPartAdded + description: >- + Streaming event for when a new content part is added to a response item. + "OpenAIResponseObjectStreamResponseContentPartDone": + type: object + properties: + response_id: + type: string + description: >- + Unique identifier of the response containing this content + item_id: + type: string + description: >- + Unique identifier of the output item containing this content part + part: + $ref: '#/components/schemas/OpenAIResponseContentPart' + description: The completed content part + sequence_number: + type: integer + description: >- + Sequential number for ordering streaming events + type: + type: string + const: response.content_part.done + default: response.content_part.done + description: >- + Event type identifier, always "response.content_part.done" + additionalProperties: false + required: + - response_id + - item_id + - part + - sequence_number + - type + title: >- + OpenAIResponseObjectStreamResponseContentPartDone + description: >- + Streaming event for when a content part is completed. "OpenAIResponseObjectStreamResponseCreated": type: object properties: @@ -10804,6 +10951,7 @@ components: type: string enum: - assistants + - batch title: OpenAIFilePurpose description: >- Valid purpose values for OpenAI Files API. @@ -10872,6 +11020,7 @@ components: type: string enum: - assistants + - batch description: The intended purpose of the file additionalProperties: false required: @@ -12286,10 +12435,6 @@ components: type: number description: >- A list of the categories along with their scores as predicted by model. - Required set of categories that need to be in response - violence - violence/graphic - - harassment - harassment/threatening - hate - hate/threatening - illicit - - illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent - - self-harm/instructions user_message: type: string metadata: diff --git a/docs/source/apis/external.md b/docs/source/apis/external.md index cc13deb9b..5831990b0 100644 --- a/docs/source/apis/external.md +++ b/docs/source/apis/external.md @@ -111,7 +111,7 @@ name = "llama-stack-api-weather" version = "0.1.0" description = "Weather API for Llama Stack" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.12" dependencies = ["llama-stack", "pydantic"] [build-system] @@ -231,7 +231,7 @@ name = "llama-stack-provider-kaze" version = "0.1.0" description = "Kaze weather provider for Llama Stack" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.12" dependencies = ["llama-stack", "pydantic", "aiohttp"] [build-system] diff --git a/docs/source/building_applications/responses_vs_agents.md b/docs/source/building_applications/responses_vs_agents.md index 3eebfb460..5abe951d6 100644 --- a/docs/source/building_applications/responses_vs_agents.md +++ b/docs/source/building_applications/responses_vs_agents.md @@ -2,7 +2,9 @@ Llama Stack (LLS) provides two different APIs for building AI applications with tool calling capabilities: the **Agents API** and the **OpenAI Responses API**. While both enable AI systems to use tools, and maintain full conversation history, they serve different use cases and have distinct characteristics. -> **Note:** For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to Agents or Responses API. +```{note} +For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to Agents or Responses API. +``` ## Overview diff --git a/docs/source/building_applications/tools.md b/docs/source/building_applications/tools.md index b19be888c..8a54290ed 100644 --- a/docs/source/building_applications/tools.md +++ b/docs/source/building_applications/tools.md @@ -76,7 +76,9 @@ Features: - Context retrieval with token limits -> **Note:** By default, llama stack run.yaml defines toolgroups for web search, wolfram alpha and rag, that are provided by tavily-search, wolfram-alpha and rag providers. +```{note} +By default, llama stack run.yaml defines toolgroups for web search, wolfram alpha and rag, that are provided by tavily-search, wolfram-alpha and rag providers. +``` ## Model Context Protocol (MCP) diff --git a/docs/source/concepts/apis.md b/docs/source/concepts/apis.md index 5a10d6498..f8f73a928 100644 --- a/docs/source/concepts/apis.md +++ b/docs/source/concepts/apis.md @@ -18,3 +18,4 @@ We are working on adding a few more APIs to complete the application lifecycle. - **Batch Inference**: run inference on a dataset of inputs - **Batch Agents**: run agents on a dataset of inputs - **Synthetic Data Generation**: generate synthetic data for model development +- **Batches**: OpenAI-compatible batch management for inference diff --git a/docs/source/contributing/index.md b/docs/source/contributing/index.md index 79c3861ea..24bf3f66c 100644 --- a/docs/source/contributing/index.md +++ b/docs/source/contributing/index.md @@ -2,24 +2,13 @@ ```{include} ../../../CONTRIBUTING.md ``` -## Testing - -See the [Test Page](testing.md) which describes how to test your changes. -```{toctree} -:maxdepth: 1 -:hidden: -:caption: Testing - -testing -``` - ## Adding a New Provider -See the [Adding a New API Provider Page](new_api_provider.md) which describes how to add new API providers to the Stack. +See: +- [Adding a New API Provider Page](new_api_provider.md) which describes how to add new API providers to the Stack. +- [Vector Database Page](new_vector_database.md) which describes how to add a new vector databases with Llama Stack. +- [External Provider Page](../providers/external/index.md) which describes how to add external providers to the Stack. -See the [Vector Database Page](new_vector_database.md) which describes how to add a new vector databases with Llama Stack. - -See the [External Provider Page](../providers/external/index.md) which describes how to add external providers to the Stack. ```{toctree} :maxdepth: 1 :hidden: @@ -27,3 +16,24 @@ See the [External Provider Page](../providers/external/index.md) which describes new_api_provider new_vector_database ``` + +## Testing + + +```{include} ../../../tests/README.md +``` + +## Benchmarking + +```{include} ../../../docs/source/distributions/k8s-benchmark/README.md +``` + +### Advanced Topics + +For developers who need deeper understanding of the testing system internals: + +```{toctree} +:maxdepth: 1 + +testing/record-replay +``` diff --git a/docs/source/contributing/testing.md b/docs/source/contributing/testing.md deleted file mode 100644 index 454ded266..000000000 --- a/docs/source/contributing/testing.md +++ /dev/null @@ -1,8 +0,0 @@ -```{include} ../../../tests/README.md -``` - -```{include} ../../../tests/unit/README.md -``` - -```{include} ../../../tests/integration/README.md -``` diff --git a/docs/source/contributing/testing/record-replay.md b/docs/source/contributing/testing/record-replay.md new file mode 100644 index 000000000..3049d333c --- /dev/null +++ b/docs/source/contributing/testing/record-replay.md @@ -0,0 +1,234 @@ +# Record-Replay System + +Understanding how Llama Stack captures and replays API interactions for testing. + +## Overview + +The record-replay system solves a fundamental challenge in AI testing: how do you test against expensive, non-deterministic APIs without breaking the bank or dealing with flaky tests? + +The solution: intercept API calls, store real responses, and replay them later. This gives you real API behavior without the cost or variability. + +## How It Works + +### Request Hashing + +Every API request gets converted to a deterministic hash for lookup: + +```python +def normalize_request(method: str, url: str, headers: dict, body: dict) -> str: + normalized = { + "method": method.upper(), + "endpoint": urlparse(url).path, # Just the path, not full URL + "body": body, # Request parameters + } + return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest() +``` + +**Key insight:** The hashing is intentionally precise. Different whitespace, float precision, or parameter order produces different hashes. This prevents subtle bugs from false cache hits. + +```python +# These produce DIFFERENT hashes: +{"content": "Hello world"} +{"content": "Hello world\n"} +{"temperature": 0.7} +{"temperature": 0.7000001} +``` + +### Client Interception + +The system patches OpenAI and Ollama client methods to intercept calls before they leave your application. This happens transparently - your test code doesn't change. + +### Storage Architecture + +Recordings use a two-tier storage system optimized for both speed and debuggability: + +``` +recordings/ +β”œβ”€β”€ index.sqlite # Fast lookup by request hash +└── responses/ + β”œβ”€β”€ abc123def456.json # Individual response files + └── def789ghi012.json +``` + +**SQLite index** enables O(log n) hash lookups and metadata queries without loading response bodies. + +**JSON files** store complete request/response pairs in human-readable format for debugging. + +## Recording Modes + +### LIVE Mode + +Direct API calls with no recording or replay: + +```python +with inference_recording(mode=InferenceMode.LIVE): + response = await client.chat.completions.create(...) +``` + +Use for initial development and debugging against real APIs. + +### RECORD Mode + +Captures API interactions while passing through real responses: + +```python +with inference_recording(mode=InferenceMode.RECORD, storage_dir="./recordings"): + response = await client.chat.completions.create(...) + # Real API call made, response captured AND returned +``` + +The recording process: +1. Request intercepted and hashed +2. Real API call executed +3. Response captured and serialized +4. Recording stored to disk +5. Original response returned to caller + +### REPLAY Mode + +Returns stored responses instead of making API calls: + +```python +with inference_recording(mode=InferenceMode.REPLAY, storage_dir="./recordings"): + response = await client.chat.completions.create(...) + # No API call made, cached response returned instantly +``` + +The replay process: +1. Request intercepted and hashed +2. Hash looked up in SQLite index +3. Response loaded from JSON file +4. Response deserialized and returned +5. Error if no recording found + +## Streaming Support + +Streaming APIs present a unique challenge: how do you capture an async generator? + +### The Problem + +```python +# How do you record this? +async for chunk in client.chat.completions.create(stream=True): + process(chunk) +``` + +### The Solution + +The system captures all chunks immediately before yielding any: + +```python +async def handle_streaming_record(response): + # Capture complete stream first + chunks = [] + async for chunk in response: + chunks.append(chunk) + + # Store complete recording + storage.store_recording( + request_hash, request_data, {"body": chunks, "is_streaming": True} + ) + + # Return generator that replays captured chunks + async def replay_stream(): + for chunk in chunks: + yield chunk + + return replay_stream() +``` + +This ensures: +- **Complete capture** - The entire stream is saved atomically +- **Interface preservation** - The returned object behaves like the original API +- **Deterministic replay** - Same chunks in the same order every time + +## Serialization + +API responses contain complex Pydantic objects that need careful serialization: + +```python +def _serialize_response(response): + if hasattr(response, "model_dump"): + # Preserve type information for proper deserialization + return { + "__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}", + "__data__": response.model_dump(mode="json"), + } + return response +``` + +This preserves type safety - when replayed, you get the same Pydantic objects with all their validation and methods. + +## Environment Integration + +### Environment Variables + +Control recording behavior globally: + +```bash +export LLAMA_STACK_TEST_INFERENCE_MODE=replay +export LLAMA_STACK_TEST_RECORDING_DIR=/path/to/recordings +pytest tests/integration/ +``` + +### Pytest Integration + +The system integrates automatically based on environment variables, requiring no changes to test code. + +## Debugging Recordings + +### Inspecting Storage + +```bash +# See what's recorded +sqlite3 recordings/index.sqlite "SELECT endpoint, model, timestamp FROM recordings LIMIT 10;" + +# View specific response +cat recordings/responses/abc123def456.json | jq '.response.body' + +# Find recordings by endpoint +sqlite3 recordings/index.sqlite "SELECT * FROM recordings WHERE endpoint='/v1/chat/completions';" +``` + +### Common Issues + +**Hash mismatches:** Request parameters changed slightly between record and replay +```bash +# Compare request details +cat recordings/responses/abc123.json | jq '.request' +``` + +**Serialization errors:** Response types changed between versions +```bash +# Re-record with updated types +rm recordings/responses/failing_hash.json +LLAMA_STACK_TEST_INFERENCE_MODE=record pytest test_failing.py +``` + +**Missing recordings:** New test or changed parameters +```bash +# Record the missing interaction +LLAMA_STACK_TEST_INFERENCE_MODE=record pytest test_new.py +``` + +## Design Decisions + +### Why Not Mocks? + +Traditional mocking breaks down with AI APIs because: +- Response structures are complex and evolve frequently +- Streaming behavior is hard to mock correctly +- Edge cases in real APIs get missed +- Mocks become brittle maintenance burdens + +### Why Precise Hashing? + +Loose hashing (normalizing whitespace, rounding floats) seems convenient but hides bugs. If a test changes slightly, you want to know about it rather than accidentally getting the wrong cached response. + +### Why JSON + SQLite? + +- **JSON** - Human readable, diff-friendly, easy to inspect and modify +- **SQLite** - Fast indexed lookups without loading response bodies +- **Hybrid** - Best of both worlds for different use cases + +This system provides reliable, fast testing against real AI APIs while maintaining the ability to debug issues when they arise. \ No newline at end of file diff --git a/docs/source/distributions/k8s-benchmark/README.md b/docs/source/distributions/k8s-benchmark/README.md new file mode 100644 index 000000000..42da4d466 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/README.md @@ -0,0 +1,156 @@ +# Llama Stack Benchmark Suite on Kubernetes + +## Motivation + +Performance benchmarking is critical for understanding the overhead and characteristics of the Llama Stack abstraction layer compared to direct inference engines like vLLM. + +### Why This Benchmark Suite Exists + +**Performance Validation**: The Llama Stack provides a unified API layer across multiple inference providers, but this abstraction introduces potential overhead. This benchmark suite quantifies the performance impact by comparing: +- Llama Stack inference (with vLLM backend) +- Direct vLLM inference calls +- Both under identical Kubernetes deployment conditions + +**Production Readiness Assessment**: Real-world deployments require understanding performance characteristics under load. This suite simulates concurrent user scenarios with configurable parameters (duration, concurrency, request patterns) to validate production readiness. + +**Regression Detection (TODO)**: As the Llama Stack evolves, this benchmark provides automated regression detection for performance changes. CI/CD pipelines can leverage these benchmarks to catch performance degradations before production deployments. + +**Resource Planning**: By measuring throughput, latency percentiles, and resource utilization patterns, teams can make informed decisions about: +- Kubernetes resource allocation (CPU, memory, GPU) +- Auto-scaling configurations +- Cost optimization strategies + +### Key Metrics Captured + +The benchmark suite measures critical performance indicators: +- **Throughput**: Requests per second under sustained load +- **Latency Distribution**: P50, P95, P99 response times +- **Time to First Token (TTFT)**: Critical for streaming applications +- **Error Rates**: Request failures and timeout analysis + +This data enables data-driven architectural decisions and performance optimization efforts. + +## Setup + +**1. Deploy base k8s infrastructure:** +```bash +cd ../k8s +./apply.sh +``` + +**2. Deploy benchmark components:** +```bash +cd ../k8s-benchmark +./apply.sh +``` + +**3. Verify deployment:** +```bash +kubectl get pods +# Should see: llama-stack-benchmark-server, vllm-server, etc. +``` + +## Quick Start + +### Basic Benchmarks + +**Benchmark Llama Stack (default):** +```bash +cd docs/source/distributions/k8s-benchmark/ +./run-benchmark.sh +``` + +**Benchmark vLLM direct:** +```bash +./run-benchmark.sh --target vllm +``` + +### Custom Configuration + +**Extended benchmark with high concurrency:** +```bash +./run-benchmark.sh --target vllm --duration 120 --concurrent 20 +``` + +**Short test run:** +```bash +./run-benchmark.sh --target stack --duration 30 --concurrent 5 +``` + +## Command Reference + +### run-benchmark.sh Options + +```bash +./run-benchmark.sh [options] + +Options: + -t, --target Target to benchmark (default: stack) + -d, --duration Duration in seconds (default: 60) + -c, --concurrent Number of concurrent users (default: 10) + -h, --help Show help message + +Examples: + ./run-benchmark.sh --target vllm # Benchmark vLLM direct + ./run-benchmark.sh --target stack # Benchmark Llama Stack + ./run-benchmark.sh -t vllm -d 120 -c 20 # vLLM with 120s, 20 users +``` + +## Local Testing + +### Running Benchmark Locally + +For local development without Kubernetes: + +**1. Start OpenAI mock server:** +```bash +uv run python openai-mock-server.py --port 8080 +``` + +**2. Run benchmark against mock server:** +```bash +uv run python benchmark.py \ + --base-url http://localhost:8080/v1 \ + --model mock-inference \ + --duration 30 \ + --concurrent 5 +``` + +**3. Test against local vLLM server:** +```bash +# If you have vLLM running locally on port 8000 +uv run python benchmark.py \ + --base-url http://localhost:8000/v1 \ + --model meta-llama/Llama-3.2-3B-Instruct \ + --duration 30 \ + --concurrent 5 +``` + +**4. Profile the running server:** +```bash +./profile_running_server.sh +``` + + + +### OpenAI Mock Server + +The `openai-mock-server.py` provides: +- **OpenAI-compatible API** for testing without real models +- **Configurable streaming delay** via `STREAM_DELAY_SECONDS` env var +- **Consistent responses** for reproducible benchmarks +- **Lightweight testing** without GPU requirements + +**Mock server usage:** +```bash +uv run python openai-mock-server.py --port 8080 +``` + +The mock server is also deployed in k8s as `openai-mock-service:8080` and can be used by changing the Llama Stack configuration to use the `mock-vllm-inference` provider. + +## Files in this Directory + +- `benchmark.py` - Core benchmark script with async streaming support +- `run-benchmark.sh` - Main script with target selection and configuration +- `openai-mock-server.py` - Mock OpenAI API server for local testing +- `README.md` - This documentation file diff --git a/docs/source/distributions/k8s-benchmark/apply.sh b/docs/source/distributions/k8s-benchmark/apply.sh new file mode 100755 index 000000000..4f2270da8 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/apply.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Deploys the benchmark-specific components on top of the base k8s deployment (../k8s/apply.sh). + +export STREAM_DELAY_SECONDS=0.005 + +export POSTGRES_USER=llamastack +export POSTGRES_DB=llamastack +export POSTGRES_PASSWORD=llamastack + +export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct +export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B + +export MOCK_INFERENCE_MODEL=mock-inference + +export MOCK_INFERENCE_URL=openai-mock-service:8080 + +export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL + +set -euo pipefail +set -x + +# Deploy benchmark-specific components +kubectl create configmap llama-stack-config --from-file=stack_run_config.yaml \ + --dry-run=client -o yaml > stack-configmap.yaml + +kubectl apply --validate=false -f stack-configmap.yaml + +# Deploy our custom llama stack server (overriding the base one) +envsubst < stack-k8s.yaml.template | kubectl apply --validate=false -f - diff --git a/docs/source/distributions/k8s-benchmark/benchmark.py b/docs/source/distributions/k8s-benchmark/benchmark.py new file mode 100644 index 000000000..0e7368431 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/benchmark.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Simple benchmark script for Llama Stack with OpenAI API compatibility. +""" + +import argparse +import asyncio +import os +import random +import statistics +import time +from typing import Tuple +import aiohttp + + +class BenchmarkStats: + def __init__(self): + self.response_times = [] + self.ttft_times = [] + self.chunks_received = [] + self.errors = [] + self.success_count = 0 + self.total_requests = 0 + self.concurrent_users = 0 + self.start_time = None + self.end_time = None + self._lock = asyncio.Lock() + + async def add_result(self, response_time: float, chunks: int, ttft: float = None, error: str = None): + async with self._lock: + self.total_requests += 1 + if error: + self.errors.append(error) + else: + self.success_count += 1 + self.response_times.append(response_time) + self.chunks_received.append(chunks) + if ttft is not None: + self.ttft_times.append(ttft) + + def print_summary(self): + if not self.response_times: + print("No successful requests to report") + if self.errors: + print(f"Total errors: {len(self.errors)}") + print("First 5 errors:") + for error in self.errors[:5]: + print(f" {error}") + return + + total_time = self.end_time - self.start_time + success_rate = (self.success_count / self.total_requests) * 100 + + print(f"\n{'='*60}") + print(f"BENCHMARK RESULTS") + print(f"{'='*60}") + print(f"Total time: {total_time:.2f}s") + print(f"Concurrent users: {self.concurrent_users}") + print(f"Total requests: {self.total_requests}") + print(f"Successful requests: {self.success_count}") + print(f"Failed requests: {len(self.errors)}") + print(f"Success rate: {success_rate:.1f}%") + print(f"Requests per second: {self.success_count / total_time:.2f}") + + print(f"\nResponse Time Statistics:") + print(f" Mean: {statistics.mean(self.response_times):.3f}s") + print(f" Median: {statistics.median(self.response_times):.3f}s") + print(f" Min: {min(self.response_times):.3f}s") + print(f" Max: {max(self.response_times):.3f}s") + + if len(self.response_times) > 1: + print(f" Std Dev: {statistics.stdev(self.response_times):.3f}s") + + percentiles = [50, 90, 95, 99] + sorted_times = sorted(self.response_times) + print(f"\nPercentiles:") + for p in percentiles: + idx = int(len(sorted_times) * p / 100) - 1 + idx = max(0, min(idx, len(sorted_times) - 1)) + print(f" P{p}: {sorted_times[idx]:.3f}s") + + if self.ttft_times: + print(f"\nTime to First Token (TTFT) Statistics:") + print(f" Mean: {statistics.mean(self.ttft_times):.3f}s") + print(f" Median: {statistics.median(self.ttft_times):.3f}s") + print(f" Min: {min(self.ttft_times):.3f}s") + print(f" Max: {max(self.ttft_times):.3f}s") + + if len(self.ttft_times) > 1: + print(f" Std Dev: {statistics.stdev(self.ttft_times):.3f}s") + + sorted_ttft = sorted(self.ttft_times) + print(f"\nTTFT Percentiles:") + for p in percentiles: + idx = int(len(sorted_ttft) * p / 100) - 1 + idx = max(0, min(idx, len(sorted_ttft) - 1)) + print(f" P{p}: {sorted_ttft[idx]:.3f}s") + + if self.chunks_received: + print(f"\nStreaming Statistics:") + print(f" Mean chunks per response: {statistics.mean(self.chunks_received):.1f}") + print(f" Total chunks received: {sum(self.chunks_received)}") + + if self.errors: + print(f"\nErrors (showing first 5):") + for error in self.errors[:5]: + print(f" {error}") + + +class LlamaStackBenchmark: + def __init__(self, base_url: str, model_id: str): + self.base_url = base_url.rstrip('/') + self.model_id = model_id + self.headers = {"Content-Type": "application/json"} + self.test_messages = [ + [{"role": "user", "content": "Hi"}], + [{"role": "user", "content": "What is the capital of France?"}], + [{"role": "user", "content": "Explain quantum physics in simple terms."}], + [{"role": "user", "content": "Write a short story about a robot learning to paint."}], + [ + {"role": "user", "content": "What is machine learning?"}, + {"role": "assistant", "content": "Machine learning is a subset of AI..."}, + {"role": "user", "content": "Can you give me a practical example?"} + ] + ] + + + async def make_async_streaming_request(self) -> Tuple[float, int, float | None, str | None]: + """Make a single async streaming chat completion request.""" + messages = random.choice(self.test_messages) + payload = { + "model": self.model_id, + "messages": messages, + "stream": True, + "max_tokens": 100 + } + + start_time = time.time() + chunks_received = 0 + ttft = None + error = None + + session = aiohttp.ClientSession() + + try: + async with session.post( + f"{self.base_url}/chat/completions", + headers=self.headers, + json=payload, + timeout=aiohttp.ClientTimeout(total=30) + ) as response: + if response.status == 200: + async for line in response.content: + if line: + line_str = line.decode('utf-8').strip() + if line_str.startswith('data: '): + chunks_received += 1 + if ttft is None: + ttft = time.time() - start_time + if line_str == 'data: [DONE]': + break + + if chunks_received == 0: + error = "No streaming chunks received" + else: + text = await response.text() + error = f"HTTP {response.status}: {text[:100]}" + + except Exception as e: + error = f"Request error: {str(e)}" + finally: + await session.close() + + response_time = time.time() - start_time + return response_time, chunks_received, ttft, error + + + async def run_benchmark(self, duration: int, concurrent_users: int) -> BenchmarkStats: + """Run benchmark using async requests for specified duration.""" + stats = BenchmarkStats() + stats.concurrent_users = concurrent_users + stats.start_time = time.time() + + print(f"Starting benchmark: {duration}s duration, {concurrent_users} concurrent users") + print(f"Target URL: {self.base_url}/chat/completions") + print(f"Model: {self.model_id}") + + connector = aiohttp.TCPConnector(limit=concurrent_users) + async with aiohttp.ClientSession(connector=connector) as session: + + async def worker(worker_id: int): + """Worker that sends requests sequentially until canceled.""" + request_count = 0 + while True: + try: + response_time, chunks, ttft, error = await self.make_async_streaming_request() + await stats.add_result(response_time, chunks, ttft, error) + request_count += 1 + + except asyncio.CancelledError: + break + except Exception as e: + await stats.add_result(0, 0, None, f"Worker {worker_id} error: {str(e)}") + + # Progress reporting task + async def progress_reporter(): + last_report_time = time.time() + while True: + try: + await asyncio.sleep(1) # Report every second + if time.time() >= last_report_time + 10: # Report every 10 seconds + elapsed = time.time() - stats.start_time + print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s") + last_report_time = time.time() + except asyncio.CancelledError: + break + + # Spawn concurrent workers + tasks = [asyncio.create_task(worker(i)) for i in range(concurrent_users)] + progress_task = asyncio.create_task(progress_reporter()) + tasks.append(progress_task) + + # Wait for duration then cancel all tasks + await asyncio.sleep(duration) + + for task in tasks: + task.cancel() + + # Wait for all tasks to complete + await asyncio.gather(*tasks, return_exceptions=True) + + stats.end_time = time.time() + return stats + + +def main(): + parser = argparse.ArgumentParser(description="Llama Stack Benchmark Tool") + parser.add_argument("--base-url", default=os.getenv("BENCHMARK_BASE_URL", "http://localhost:8000/v1/openai/v1"), + help="Base URL for the API (default: http://localhost:8000/v1/openai/v1)") + parser.add_argument("--model", default=os.getenv("INFERENCE_MODEL", "test-model"), + help="Model ID to use for requests") + parser.add_argument("--duration", type=int, default=60, + help="Duration in seconds to run benchmark (default: 60)") + parser.add_argument("--concurrent", type=int, default=10, + help="Number of concurrent users (default: 10)") + + args = parser.parse_args() + + benchmark = LlamaStackBenchmark(args.base_url, args.model) + + try: + stats = asyncio.run(benchmark.run_benchmark(args.duration, args.concurrent)) + stats.print_summary() + + except KeyboardInterrupt: + print("\nBenchmark interrupted by user") + except Exception as e: + print(f"Benchmark failed: {e}") + + +if __name__ == "__main__": + main() diff --git a/docs/source/distributions/k8s-benchmark/openai-mock-server.py b/docs/source/distributions/k8s-benchmark/openai-mock-server.py new file mode 100755 index 000000000..de0680842 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/openai-mock-server.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +OpenAI-compatible mock server that returns: +- Hardcoded /models response for consistent validation +- Valid OpenAI-formatted chat completion responses with dynamic content +""" + +from flask import Flask, request, jsonify, Response +import time +import random +import uuid +import json +import argparse +import os + +app = Flask(__name__) + +# Models from environment variables +def get_models(): + models_str = os.getenv("MOCK_MODELS", "meta-llama/Llama-3.2-3B-Instruct") + model_ids = [m.strip() for m in models_str.split(",") if m.strip()] + + return { + "object": "list", + "data": [ + { + "id": model_id, + "object": "model", + "created": 1234567890, + "owned_by": "vllm" + } + for model_id in model_ids + ] + } + +def generate_random_text(length=50): + """Generate random but coherent text for responses.""" + words = [ + "Hello", "there", "I'm", "an", "AI", "assistant", "ready", "to", "help", "you", + "with", "your", "questions", "and", "tasks", "today", "Let", "me","know", "what", + "you'd", "like", "to", "discuss", "or", "explore", "together", "I", "can", "assist", + "with", "various", "topics", "including", "coding", "writing", "analysis", "and", "more" + ] + return " ".join(random.choices(words, k=length)) + +@app.route('/v1/models', methods=['GET']) +def list_models(): + models = get_models() + print(f"[MOCK] Returning models: {[m['id'] for m in models['data']]}") + return jsonify(models) + +@app.route('/v1/chat/completions', methods=['POST']) +def chat_completions(): + """Return OpenAI-formatted chat completion responses.""" + data = request.get_json() + default_model = get_models()['data'][0]['id'] + model = data.get('model', default_model) + messages = data.get('messages', []) + stream = data.get('stream', False) + + print(f"[MOCK] Chat completion request - model: {model}, stream: {stream}") + + if stream: + return handle_streaming_completion(model, messages) + else: + return handle_non_streaming_completion(model, messages) + +def handle_non_streaming_completion(model, messages): + response_text = generate_random_text(random.randint(20, 80)) + + # Calculate realistic token counts + prompt_tokens = sum(len(str(msg.get('content', '')).split()) for msg in messages) + completion_tokens = len(response_text.split()) + + response = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": response_text + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens + } + } + + return jsonify(response) + +def handle_streaming_completion(model, messages): + def generate_stream(): + # Generate response text + full_response = generate_random_text(random.randint(30, 100)) + words = full_response.split() + + # Send initial chunk + initial_chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""} + } + ] + } + yield f"data: {json.dumps(initial_chunk)}\n\n" + + # Send word by word + for i, word in enumerate(words): + chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": f"{word} " if i < len(words) - 1 else word} + } + ] + } + yield f"data: {json.dumps(chunk)}\n\n" + # Configurable delay to simulate realistic streaming + stream_delay = float(os.getenv("STREAM_DELAY_SECONDS", "0.005")) + time.sleep(stream_delay) + + # Send final chunk + final_chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": ""}, + "finish_reason": "stop" + } + ] + } + yield f"data: {json.dumps(final_chunk)}\n\n" + yield "data: [DONE]\n\n" + + return Response( + generate_stream(), + mimetype='text/event-stream', + headers={ + 'Cache-Control': 'no-cache', + 'Connection': 'keep-alive', + 'Access-Control-Allow-Origin': '*', + } + ) + +@app.route('/health', methods=['GET']) +def health(): + return jsonify({"status": "healthy", "type": "openai-mock"}) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='OpenAI-compatible mock server') + parser.add_argument('--port', type=int, default=8081, + help='Port to run the server on (default: 8081)') + args = parser.parse_args() + + port = args.port + + models = get_models() + print("Starting OpenAI-compatible mock server...") + print(f"- /models endpoint with: {[m['id'] for m in models['data']]}") + print("- OpenAI-formatted chat/completion responses with dynamic content") + print("- Streaming support with valid SSE format") + print(f"- Listening on: http://0.0.0.0:{port}") + app.run(host='0.0.0.0', port=port, debug=False) diff --git a/docs/source/distributions/k8s-benchmark/profile_running_server.sh b/docs/source/distributions/k8s-benchmark/profile_running_server.sh new file mode 100755 index 000000000..65d620583 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/profile_running_server.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Script to profile an already running Llama Stack server +# Usage: ./profile_running_server.sh [duration_seconds] [output_file] + +DURATION=${1:-60} # Default 60 seconds +OUTPUT_FILE=${2:-"llama_stack_profile"} # Default output file + +echo "Looking for running Llama Stack server..." + +# Find the server PID +SERVER_PID=$(ps aux | grep "llama_stack.core.server.server" | grep -v grep | awk '{print $2}' | head -1) + + +if [ -z "$SERVER_PID" ]; then + echo "Error: No running Llama Stack server found" + echo "Please start your server first with:" + echo "LLAMA_STACK_LOGGING=\"all=ERROR\" MOCK_INFERENCE_URL=http://localhost:8080 SAFETY_MODEL=llama-guard3:1b uv run --with llama-stack python -m llama_stack.core.server.server docs/source/distributions/k8s-benchmark/stack_run_config.yaml" + exit 1 +fi + +echo "Found Llama Stack server with PID: $SERVER_PID" + +# Start py-spy profiling +echo "Starting py-spy profiling for ${DURATION} seconds..." +echo "Output will be saved to: ${OUTPUT_FILE}.svg" +echo "" +echo "You can now run your load test..." +echo "" + +# Get the full path to py-spy +PYSPY_PATH=$(which py-spy) + +# Check if running as root, if not, use sudo +if [ "$EUID" -ne 0 ]; then + echo "py-spy requires root permissions on macOS. Running with sudo..." + sudo "$PYSPY_PATH" record -o "${OUTPUT_FILE}.svg" -d ${DURATION} -p $SERVER_PID +else + "$PYSPY_PATH" record -o "${OUTPUT_FILE}.svg" -d ${DURATION} -p $SERVER_PID +fi + +echo "" +echo "Profiling completed! Results saved to: ${OUTPUT_FILE}.svg" +echo "" +echo "To view the flame graph:" +echo "open ${OUTPUT_FILE}.svg" diff --git a/docs/source/distributions/k8s-benchmark/run-benchmark.sh b/docs/source/distributions/k8s-benchmark/run-benchmark.sh new file mode 100755 index 000000000..e1c826143 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/run-benchmark.sh @@ -0,0 +1,148 @@ +#!/usr/bin/env bash + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +set -euo pipefail + +# Default values +TARGET="stack" +DURATION=60 +CONCURRENT=10 + +# Parse command line arguments +usage() { + echo "Usage: $0 [options]" + echo "Options:" + echo " -t, --target Target to benchmark (default: stack)" + echo " -d, --duration Duration in seconds (default: 60)" + echo " -c, --concurrent Number of concurrent users (default: 10)" + echo " -h, --help Show this help message" + echo "" + echo "Examples:" + echo " $0 --target vllm # Benchmark vLLM direct" + echo " $0 --target stack # Benchmark Llama Stack (default)" + echo " $0 -t vllm -d 120 -c 20 # vLLM with 120s duration, 20 users" +} + +while [[ $# -gt 0 ]]; do + case $1 in + -t|--target) + TARGET="$2" + shift 2 + ;; + -d|--duration) + DURATION="$2" + shift 2 + ;; + -c|--concurrent) + CONCURRENT="$2" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown option: $1" + usage + exit 1 + ;; + esac +done + +# Validate target +if [[ "$TARGET" != "stack" && "$TARGET" != "vllm" ]]; then + echo "Error: Target must be 'stack' or 'vllm'" + usage + exit 1 +fi + +# Set configuration based on target +if [[ "$TARGET" == "vllm" ]]; then + BASE_URL="http://vllm-server:8000/v1" + JOB_NAME="vllm-benchmark-job" + echo "Benchmarking vLLM direct..." +else + BASE_URL="http://llama-stack-benchmark-service:8323/v1/openai/v1" + JOB_NAME="stack-benchmark-job" + echo "Benchmarking Llama Stack..." +fi + +echo "Configuration:" +echo " Target: $TARGET" +echo " Base URL: $BASE_URL" +echo " Duration: ${DURATION}s" +echo " Concurrent users: $CONCURRENT" +echo "" + +# Create temporary job yaml +TEMP_YAML="/tmp/benchmark-job-temp-$(date +%s).yaml" +cat > "$TEMP_YAML" << EOF +apiVersion: batch/v1 +kind: Job +metadata: + name: $JOB_NAME + namespace: default +spec: + template: + spec: + containers: + - name: benchmark + image: python:3.11-slim + command: ["/bin/bash"] + args: + - "-c" + - | + pip install aiohttp && + python3 /benchmark/benchmark.py \\ + --base-url $BASE_URL \\ + --model \${INFERENCE_MODEL} \\ + --duration $DURATION \\ + --concurrent $CONCURRENT + env: + - name: INFERENCE_MODEL + value: "meta-llama/Llama-3.2-3B-Instruct" + volumeMounts: + - name: benchmark-script + mountPath: /benchmark + resources: + requests: + memory: "256Mi" + cpu: "250m" + limits: + memory: "512Mi" + cpu: "500m" + volumes: + - name: benchmark-script + configMap: + name: benchmark-script + restartPolicy: Never + backoffLimit: 3 +EOF + +echo "Creating benchmark ConfigMap..." +kubectl create configmap benchmark-script \ + --from-file=benchmark.py=benchmark.py \ + --dry-run=client -o yaml | kubectl apply -f - + +echo "Cleaning up any existing benchmark job..." +kubectl delete job $JOB_NAME 2>/dev/null || true + +echo "Deploying benchmark Job..." +kubectl apply -f "$TEMP_YAML" + +echo "Waiting for job to start..." +kubectl wait --for=condition=Ready pod -l job-name=$JOB_NAME --timeout=60s + +echo "Following benchmark logs..." +kubectl logs -f job/$JOB_NAME + +echo "Job completed. Checking final status..." +kubectl get job $JOB_NAME + +# Clean up temporary file +rm -f "$TEMP_YAML" diff --git a/docs/source/distributions/k8s-benchmark/stack-configmap.yaml b/docs/source/distributions/k8s-benchmark/stack-configmap.yaml new file mode 100644 index 000000000..edf4ebd75 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/stack-configmap.yaml @@ -0,0 +1,133 @@ +apiVersion: v1 +data: + stack_run_config.yaml: | + version: '2' + image_name: kubernetes-benchmark-demo + apis: + - agents + - inference + - safety + - telemetry + - tool_runtime + - vector_io + providers: + inference: + - provider_id: vllm-inference + provider_type: remote::vllm + config: + url: ${env.VLLM_URL:=http://localhost:8000/v1} + max_tokens: ${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: vllm-safety + provider_type: remote::vllm + config: + url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1} + max_tokens: ${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} + vector_io: + - provider_id: ${env.ENABLE_CHROMADB:+chromadb} + provider_type: remote::chromadb + config: + url: ${env.CHROMADB_URL:=} + kvstore: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + responses_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" + sinks: ${env.TELEMETRY_SINKS:=console} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:+} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:+} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} + metadata_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + table_name: llamastack_kvstore + inference_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + models: + - metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: sentence-transformers + model_type: embedding + - model_id: ${env.INFERENCE_MODEL} + provider_id: vllm-inference + model_type: llm + - model_id: ${env.SAFETY_MODEL} + provider_id: vllm-safety + model_type: llm + shields: + - shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B} + vector_dbs: [] + datasets: [] + scoring_fns: [] + benchmarks: [] + tool_groups: + - toolgroup_id: builtin::websearch + provider_id: tavily-search + - toolgroup_id: builtin::rag + provider_id: rag-runtime + server: + port: 8323 +kind: ConfigMap +metadata: + creationTimestamp: null + name: llama-stack-config diff --git a/docs/source/distributions/k8s-benchmark/stack-k8s.yaml.template b/docs/source/distributions/k8s-benchmark/stack-k8s.yaml.template new file mode 100644 index 000000000..9cb1e5be3 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/stack-k8s.yaml.template @@ -0,0 +1,83 @@ +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: llama-benchmark-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 1Gi +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: llama-stack-benchmark-server +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/name: llama-stack-benchmark + app.kubernetes.io/component: server + template: + metadata: + labels: + app.kubernetes.io/name: llama-stack-benchmark + app.kubernetes.io/component: server + spec: + containers: + - name: llama-stack-benchmark + image: llamastack/distribution-starter:latest + imagePullPolicy: Always # since we have specified latest instead of a version + env: + - name: ENABLE_CHROMADB + value: "true" + - name: CHROMADB_URL + value: http://chromadb.default.svc.cluster.local:6000 + - name: POSTGRES_HOST + value: postgres-server.default.svc.cluster.local + - name: POSTGRES_PORT + value: "5432" + - name: INFERENCE_MODEL + value: "${INFERENCE_MODEL}" + - name: SAFETY_MODEL + value: "${SAFETY_MODEL}" + - name: TAVILY_SEARCH_API_KEY + value: "${TAVILY_SEARCH_API_KEY}" + - name: VLLM_URL + value: http://vllm-server.default.svc.cluster.local:8000/v1 + - name: VLLM_MAX_TOKENS + value: "3072" + - name: VLLM_SAFETY_URL + value: http://vllm-server-safety.default.svc.cluster.local:8001/v1 + - name: VLLM_TLS_VERIFY + value: "false" + command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"] + ports: + - containerPort: 8323 + volumeMounts: + - name: llama-storage + mountPath: /root/.llama + - name: llama-config + mountPath: /etc/config + volumes: + - name: llama-storage + persistentVolumeClaim: + claimName: llama-benchmark-pvc + - name: llama-config + configMap: + name: llama-stack-config +--- +apiVersion: v1 +kind: Service +metadata: + name: llama-stack-benchmark-service +spec: + selector: + app.kubernetes.io/name: llama-stack-benchmark + app.kubernetes.io/component: server + ports: + - name: http + port: 8323 + targetPort: 8323 + type: ClusterIP diff --git a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml new file mode 100644 index 000000000..ceb1ba2d9 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml @@ -0,0 +1,108 @@ +version: '2' +image_name: kubernetes-benchmark-demo +apis: +- agents +- inference +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: vllm-inference + provider_type: remote::vllm + config: + url: ${env.VLLM_URL:=http://localhost:8000/v1} + max_tokens: ${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} + vector_io: + - provider_id: ${env.ENABLE_CHROMADB:+chromadb} + provider_type: remote::chromadb + config: + url: ${env.CHROMADB_URL:=} + kvstore: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + responses_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" + sinks: ${env.TELEMETRY_SINKS:=console} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:+} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:+} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} +metadata_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + table_name: llamastack_kvstore +inference_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} +models: +- metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: sentence-transformers + model_type: embedding +- model_id: ${env.INFERENCE_MODEL} + provider_id: vllm-inference + model_type: llm +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +server: + port: 8323 diff --git a/docs/source/distributions/k8s/stack-k8s.yaml.template b/docs/source/distributions/k8s/stack-k8s.yaml.template index ad5d2c716..dfc049f4f 100644 --- a/docs/source/distributions/k8s/stack-k8s.yaml.template +++ b/docs/source/distributions/k8s/stack-k8s.yaml.template @@ -40,19 +40,19 @@ spec: value: "3072" - name: VLLM_SAFETY_URL value: http://vllm-server-safety.default.svc.cluster.local:8001/v1 + - name: VLLM_TLS_VERIFY + value: "false" - name: POSTGRES_HOST value: postgres-server.default.svc.cluster.local - name: POSTGRES_PORT value: "5432" - - name: VLLM_TLS_VERIFY - value: "false" - name: INFERENCE_MODEL value: "${INFERENCE_MODEL}" - name: SAFETY_MODEL value: "${SAFETY_MODEL}" - name: TAVILY_SEARCH_API_KEY value: "${TAVILY_SEARCH_API_KEY}" - command: ["python", "-m", "llama_stack.core.server.server", "--config", "/etc/config/stack_run_config.yaml", "--port", "8321"] + command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8321"] ports: - containerPort: 8321 volumeMounts: diff --git a/docs/source/providers/agents/index.md b/docs/source/providers/agents/index.md index 92bf9edc0..a2c48d4b9 100644 --- a/docs/source/providers/agents/index.md +++ b/docs/source/providers/agents/index.md @@ -2,6 +2,15 @@ ## Overview +Agents API for creating and interacting with agentic systems. + + Main functionalities provided by this API: + - Create agents with specific instructions and ability to use tools. + - Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn". + - Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details). + - Agents can be provided with various shields (see the Safety API for more details). + - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details. + This section contains documentation for all available providers for the **agents** API. ## Providers diff --git a/docs/source/providers/batches/index.md b/docs/source/providers/batches/index.md new file mode 100644 index 000000000..2a39a626c --- /dev/null +++ b/docs/source/providers/batches/index.md @@ -0,0 +1,21 @@ +# Batches + +## Overview + +Protocol for batch processing API operations. + + The Batches API enables efficient processing of multiple requests in a single operation, + particularly useful for processing large datasets, batch evaluation workflows, and + cost-effective inference at scale. + + Note: This API is currently under active development and may undergo changes. + +This section contains documentation for all available providers for the **batches** API. + +## Providers + +```{toctree} +:maxdepth: 1 + +inline_reference +``` diff --git a/docs/source/providers/batches/inline_reference.md b/docs/source/providers/batches/inline_reference.md new file mode 100644 index 000000000..a58e5124d --- /dev/null +++ b/docs/source/providers/batches/inline_reference.md @@ -0,0 +1,23 @@ +# inline::reference + +## Description + +Reference implementation of batches API with KVStore persistence. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Configuration for the key-value store backend. | +| `max_concurrent_batches` | `` | No | 1 | Maximum number of concurrent batches to process simultaneously. | +| `max_concurrent_requests_per_batch` | `` | No | 10 | Maximum number of concurrent requests to process per batch. | + +## Sample Configuration + +```yaml +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db + +``` + diff --git a/docs/source/providers/eval/index.md b/docs/source/providers/eval/index.md index d180d256c..a14fada1d 100644 --- a/docs/source/providers/eval/index.md +++ b/docs/source/providers/eval/index.md @@ -2,6 +2,8 @@ ## Overview +Llama Stack Evaluation API for running evaluations on model and agent candidates. + This section contains documentation for all available providers for the **eval** API. ## Providers diff --git a/docs/source/providers/external/external-providers-guide.md b/docs/source/providers/external/external-providers-guide.md index 2479d406f..e2d4ebea9 100644 --- a/docs/source/providers/external/external-providers-guide.md +++ b/docs/source/providers/external/external-providers-guide.md @@ -226,7 +226,7 @@ uv init name = "llama-stack-provider-ollama" version = "0.1.0" description = "Ollama provider for Llama Stack" -requires-python = ">=3.10" +requires-python = ">=3.12" dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"] ``` diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md index 38781e5eb..b6d215474 100644 --- a/docs/source/providers/inference/index.md +++ b/docs/source/providers/inference/index.md @@ -2,6 +2,12 @@ ## Overview +Llama Stack Inference API for generating completions, chat completions, and embeddings. + + This API provides the raw interface to the underlying models. Two kinds of models are supported: + - LLM models: these models generate "raw" and "chat" (conversational) completions. + - Embedding models: these models generate embeddings to be used for semantic search. + This section contains documentation for all available providers for the **inference** API. ## Providers diff --git a/docs/source/providers/vector_io/inline_meta-reference.md b/docs/source/providers/vector_io/inline_meta-reference.md index 0aac445bd..6f269c441 100644 --- a/docs/source/providers/vector_io/inline_meta-reference.md +++ b/docs/source/providers/vector_io/inline_meta-reference.md @@ -21,5 +21,7 @@ kvstore: ## Deprecation Notice -⚠️ **Warning**: Please use the `inline::faiss` provider instead. +```{warning} +Please use the `inline::faiss` provider instead. +``` diff --git a/docs/source/providers/vector_io/inline_sqlite_vec.md b/docs/source/providers/vector_io/inline_sqlite_vec.md index 7ad8eb252..9e5654a50 100644 --- a/docs/source/providers/vector_io/inline_sqlite_vec.md +++ b/docs/source/providers/vector_io/inline_sqlite_vec.md @@ -25,5 +25,7 @@ kvstore: ## Deprecation Notice -⚠️ **Warning**: Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead. +```{warning} +Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead. +``` diff --git a/docs/source/providers/vector_io/remote_milvus.md b/docs/source/providers/vector_io/remote_milvus.md index 2af64b8bb..075423d04 100644 --- a/docs/source/providers/vector_io/remote_milvus.md +++ b/docs/source/providers/vector_io/remote_milvus.md @@ -204,7 +204,10 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi | `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend | | `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. | -> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider. +```{note} + This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider. + ``` + ## Sample Configuration diff --git a/docs/source/references/llama_cli_reference/download_models.md b/docs/source/references/llama_cli_reference/download_models.md index e32099023..a9af65349 100644 --- a/docs/source/references/llama_cli_reference/download_models.md +++ b/docs/source/references/llama_cli_reference/download_models.md @@ -128,7 +128,9 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern **Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). -> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. +```{tip} +Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. +``` ## List the downloaded models diff --git a/docs/source/references/llama_cli_reference/index.md b/docs/source/references/llama_cli_reference/index.md index 4ef76fe7d..09a8b7177 100644 --- a/docs/source/references/llama_cli_reference/index.md +++ b/docs/source/references/llama_cli_reference/index.md @@ -152,7 +152,9 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern **Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). -> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. +```{tip} +Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. +``` ## List the downloaded models diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index e816da766..7dd3e9289 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -706,6 +706,7 @@ class Agents(Protocol): temperature: float | None = None, text: OpenAIResponseText | None = None, tools: list[OpenAIResponseInputTool] | None = None, + include: list[str] | None = None, max_infer_iters: int | None = 10, # this is an extension to the OpenAI API ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]: """Create a new OpenAI response. @@ -713,6 +714,7 @@ class Agents(Protocol): :param input: Input message(s) to create the response. :param model: The underlying LLM used for completions. :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses. + :param include: (Optional) Additional fields to include in the response. :returns: An OpenAIResponseObject. """ ... diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index 10cadf38f..591992479 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -170,6 +170,23 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel): type: Literal["web_search_call"] = "web_search_call" +class OpenAIResponseOutputMessageFileSearchToolCallResults(BaseModel): + """Search results returned by the file search operation. + + :param attributes: (Optional) Key-value attributes associated with the file + :param file_id: Unique identifier of the file containing the result + :param filename: Name of the file containing the result + :param score: Relevance score for this search result (between 0 and 1) + :param text: Text content of the search result + """ + + attributes: dict[str, Any] + file_id: str + filename: str + score: float + text: str + + @json_schema_type class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel): """File search tool call output message for OpenAI responses. @@ -185,7 +202,7 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel): queries: list[str] status: str type: Literal["file_search_call"] = "file_search_call" - results: list[dict[str, Any]] | None = None + results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None @json_schema_type @@ -606,6 +623,62 @@ class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel): type: Literal["response.mcp_call.completed"] = "response.mcp_call.completed" +@json_schema_type +class OpenAIResponseContentPartOutputText(BaseModel): + type: Literal["output_text"] = "output_text" + text: str + # TODO: add annotations, logprobs, etc. + + +@json_schema_type +class OpenAIResponseContentPartRefusal(BaseModel): + type: Literal["refusal"] = "refusal" + refusal: str + + +OpenAIResponseContentPart = Annotated[ + OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal, + Field(discriminator="type"), +] +register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart") + + +@json_schema_type +class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel): + """Streaming event for when a new content part is added to a response item. + + :param response_id: Unique identifier of the response containing this content + :param item_id: Unique identifier of the output item containing this content part + :param part: The content part that was added + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.content_part.added" + """ + + response_id: str + item_id: str + part: OpenAIResponseContentPart + sequence_number: int + type: Literal["response.content_part.added"] = "response.content_part.added" + + +@json_schema_type +class OpenAIResponseObjectStreamResponseContentPartDone(BaseModel): + """Streaming event for when a content part is completed. + + :param response_id: Unique identifier of the response containing this content + :param item_id: Unique identifier of the output item containing this content part + :param part: The completed content part + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.content_part.done" + """ + + response_id: str + item_id: str + part: OpenAIResponseContentPart + sequence_number: int + type: Literal["response.content_part.done"] = "response.content_part.done" + + OpenAIResponseObjectStream = Annotated[ OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseOutputItemAdded @@ -625,6 +698,8 @@ OpenAIResponseObjectStream = Annotated[ | OpenAIResponseObjectStreamResponseMcpCallInProgress | OpenAIResponseObjectStreamResponseMcpCallFailed | OpenAIResponseObjectStreamResponseMcpCallCompleted + | OpenAIResponseObjectStreamResponseContentPartAdded + | OpenAIResponseObjectStreamResponseContentPartDone | OpenAIResponseObjectStreamResponseCompleted, Field(discriminator="type"), ] diff --git a/llama_stack/apis/batches/__init__.py b/llama_stack/apis/batches/__init__.py new file mode 100644 index 000000000..9ce7d3d75 --- /dev/null +++ b/llama_stack/apis/batches/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .batches import Batches, BatchObject, ListBatchesResponse + +__all__ = ["Batches", "BatchObject", "ListBatchesResponse"] diff --git a/llama_stack/apis/batches/batches.py b/llama_stack/apis/batches/batches.py new file mode 100644 index 000000000..9297d8597 --- /dev/null +++ b/llama_stack/apis/batches/batches.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Literal, Protocol, runtime_checkable + +from pydantic import BaseModel, Field + +from llama_stack.schema_utils import json_schema_type, webmethod + +try: + from openai.types import Batch as BatchObject +except ImportError as e: + raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e + + +@json_schema_type +class ListBatchesResponse(BaseModel): + """Response containing a list of batch objects.""" + + object: Literal["list"] = "list" + data: list[BatchObject] = Field(..., description="List of batch objects") + first_id: str | None = Field(default=None, description="ID of the first batch in the list") + last_id: str | None = Field(default=None, description="ID of the last batch in the list") + has_more: bool = Field(default=False, description="Whether there are more batches available") + + +@runtime_checkable +class Batches(Protocol): + """Protocol for batch processing API operations. + + The Batches API enables efficient processing of multiple requests in a single operation, + particularly useful for processing large datasets, batch evaluation workflows, and + cost-effective inference at scale. + + Note: This API is currently under active development and may undergo changes. + """ + + @webmethod(route="/openai/v1/batches", method="POST") + async def create_batch( + self, + input_file_id: str, + endpoint: str, + completion_window: Literal["24h"], + metadata: dict[str, str] | None = None, + ) -> BatchObject: + """Create a new batch for processing multiple API requests. + + :param input_file_id: The ID of an uploaded file containing requests for the batch. + :param endpoint: The endpoint to be used for all requests in the batch. + :param completion_window: The time window within which the batch should be processed. + :param metadata: Optional metadata for the batch. + :returns: The created batch object. + """ + ... + + @webmethod(route="/openai/v1/batches/{batch_id}", method="GET") + async def retrieve_batch(self, batch_id: str) -> BatchObject: + """Retrieve information about a specific batch. + + :param batch_id: The ID of the batch to retrieve. + :returns: The batch object. + """ + ... + + @webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST") + async def cancel_batch(self, batch_id: str) -> BatchObject: + """Cancel a batch that is in progress. + + :param batch_id: The ID of the batch to cancel. + :returns: The updated batch object. + """ + ... + + @webmethod(route="/openai/v1/batches", method="GET") + async def list_batches( + self, + after: str | None = None, + limit: int = 20, + ) -> ListBatchesResponse: + """List all batches for the current user. + + :param after: A cursor for pagination; returns batches after this batch ID. + :param limit: Number of batches to return (default 20, max 100). + :returns: A list of batch objects. + """ + ... diff --git a/llama_stack/apis/common/errors.py b/llama_stack/apis/common/errors.py index 95d6ac18e..ec3d2b1ce 100644 --- a/llama_stack/apis/common/errors.py +++ b/llama_stack/apis/common/errors.py @@ -62,3 +62,20 @@ class SessionNotFoundError(ValueError): def __init__(self, session_name: str) -> None: message = f"Session '{session_name}' not found or access denied." super().__init__(message) + + +class ModelTypeError(TypeError): + """raised when a model is present but not the correct type""" + + def __init__(self, model_name: str, model_type: str, expected_model_type: str) -> None: + message = ( + f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'" + ) + super().__init__(message) + + +class ConflictError(ValueError): + """raised when an operation cannot be performed due to a conflict with the current state""" + + def __init__(self, message: str) -> None: + super().__init__(message) diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index cabe46a2f..87fc95917 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -86,6 +86,7 @@ class Api(Enum, metaclass=DynamicApiMeta): :cvar inference: Text generation, chat completions, and embeddings :cvar safety: Content moderation and safety shields :cvar agents: Agent orchestration and execution + :cvar batches: Batch processing for asynchronous API requests :cvar vector_io: Vector database operations and queries :cvar datasetio: Dataset input/output operations :cvar scoring: Model output evaluation and scoring @@ -108,6 +109,7 @@ class Api(Enum, metaclass=DynamicApiMeta): inference = "inference" safety = "safety" agents = "agents" + batches = "batches" vector_io = "vector_io" datasetio = "datasetio" scoring = "scoring" diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py index ba8701e23..a1b9dd4dc 100644 --- a/llama_stack/apis/files/files.py +++ b/llama_stack/apis/files/files.py @@ -22,6 +22,7 @@ class OpenAIFilePurpose(StrEnum): """ ASSISTANTS = "assistants" + BATCH = "batch" # TODO: Add other purposes as needed diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 3f374460b..25ee03ec1 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum, StrEnum +from enum import Enum from typing import Any, Protocol, runtime_checkable from pydantic import BaseModel, Field @@ -15,27 +15,6 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod -# OpenAI Categories to return in the response -class OpenAICategories(StrEnum): - """ - Required set of categories in moderations api response - """ - - VIOLENCE = "violence" - VIOLENCE_GRAPHIC = "violence/graphic" - HARRASMENT = "harassment" - HARRASMENT_THREATENING = "harassment/threatening" - HATE = "hate" - HATE_THREATENING = "hate/threatening" - ILLICIT = "illicit" - ILLICIT_VIOLENT = "illicit/violent" - SEXUAL = "sexual" - SEXUAL_MINORS = "sexual/minors" - SELF_HARM = "self-harm" - SELF_HARM_INTENT = "self-harm/intent" - SELF_HARM_INSTRUCTIONS = "self-harm/instructions" - - @json_schema_type class ModerationObjectResults(BaseModel): """A moderation object. @@ -43,20 +22,6 @@ class ModerationObjectResults(BaseModel): :param categories: A list of the categories, and whether they are flagged or not. :param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to. :param category_scores: A list of the categories along with their scores as predicted by model. - Required set of categories that need to be in response - - violence - - violence/graphic - - harassment - - harassment/threatening - - hate - - hate/threatening - - illicit - - illicit/violent - - sexual - - sexual/minors - - self-harm - - self-harm/intent - - self-harm/instructions """ flagged: bool diff --git a/llama_stack/core/build.py b/llama_stack/core/build.py index b3e35ecef..4b20588fd 100644 --- a/llama_stack/core/build.py +++ b/llama_stack/core/build.py @@ -91,7 +91,7 @@ def get_provider_dependencies( def print_pip_install_help(config: BuildConfig): - normal_deps, special_deps = get_provider_dependencies(config) + normal_deps, special_deps, _ = get_provider_dependencies(config) cprint( f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}", diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index 5fbbf1aff..a93fe509e 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -380,8 +380,17 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): json_content = json.dumps(convert_pydantic_to_json_value(result)) filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)} + + status_code = httpx.codes.OK + + if options.method.upper() == "DELETE" and result is None: + status_code = httpx.codes.NO_CONTENT + + if status_code == httpx.codes.NO_CONTENT: + json_content = "" + mock_response = httpx.Response( - status_code=httpx.codes.OK, + status_code=status_code, content=json_content.encode("utf-8"), headers={ "Content-Type": "application/json", diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py index 70c78fb01..7ac98dac8 100644 --- a/llama_stack/core/resolver.py +++ b/llama_stack/core/resolver.py @@ -8,6 +8,7 @@ import inspect from typing import Any from llama_stack.apis.agents import Agents +from llama_stack.apis.batches import Batches from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -75,6 +76,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> Api.agents: Agents, Api.inference: Inference, Api.inspect: Inspect, + Api.batches: Batches, Api.vector_io: VectorIO, Api.vector_dbs: VectorDBs, Api.models: Models, diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 79ab7c34f..6a3f07247 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -18,7 +18,7 @@ from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, ) -from llama_stack.apis.common.errors import ModelNotFoundError +from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError from llama_stack.apis.inference import ( BatchChatCompletionResponse, BatchCompletionResponse, @@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin from llama_stack.providers.utils.inference.inference_store import InferenceStore from llama_stack.providers.utils.telemetry.tracing import get_current_span -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="inference") class InferenceRouter(Inference): @@ -177,6 +177,15 @@ class InferenceRouter(Inference): encoded = self.formatter.encode_content(messages) return len(encoded.tokens) if encoded and encoded.tokens else 0 + async def _get_model(self, model_id: str, expected_model_type: str) -> Model: + """takes a model id and gets model after ensuring that it is accessible and of the correct type""" + model = await self.routing_table.get_model(model_id) + if model is None: + raise ModelNotFoundError(model_id) + if model.model_type != expected_model_type: + raise ModelTypeError(model_id, model.model_type, expected_model_type) + return model + async def chat_completion( self, model_id: str, @@ -195,11 +204,7 @@ class InferenceRouter(Inference): ) if sampling_params is None: sampling_params = SamplingParams() - model = await self.routing_table.get_model(model_id) - if model is None: - raise ModelNotFoundError(model_id) - if model.model_type == ModelType.embedding: - raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") + model = await self._get_model(model_id, ModelType.llm) if tool_config: if tool_choice and tool_choice != tool_config.tool_choice: raise ValueError("tool_choice and tool_config.tool_choice must match") @@ -301,11 +306,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}", ) - model = await self.routing_table.get_model(model_id) - if model is None: - raise ModelNotFoundError(model_id) - if model.model_type == ModelType.embedding: - raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") + model = await self._get_model(model_id, ModelType.llm) provider = await self.routing_table.get_provider_impl(model_id) params = dict( model_id=model_id, @@ -355,11 +356,7 @@ class InferenceRouter(Inference): task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: logger.debug(f"InferenceRouter.embeddings: {model_id}") - model = await self.routing_table.get_model(model_id) - if model is None: - raise ModelNotFoundError(model_id) - if model.model_type == ModelType.llm: - raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings") + await self._get_model(model_id, ModelType.embedding) provider = await self.routing_table.get_provider_impl(model_id) return await provider.embeddings( model_id=model_id, @@ -395,12 +392,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}", ) - model_obj = await self.routing_table.get_model(model) - if model_obj is None: - raise ModelNotFoundError(model) - if model_obj.model_type == ModelType.embedding: - raise ValueError(f"Model '{model}' is an embedding model and does not support completions") - + model_obj = await self._get_model(model, ModelType.llm) params = dict( model=model_obj.identifier, prompt=prompt, @@ -476,11 +468,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}", ) - model_obj = await self.routing_table.get_model(model) - if model_obj is None: - raise ModelNotFoundError(model) - if model_obj.model_type == ModelType.embedding: - raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions") + model_obj = await self._get_model(model, ModelType.llm) # Use the OpenAI client for a bit of extra input validation without # exposing the OpenAI client itself as part of our API surface @@ -567,12 +555,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}", ) - model_obj = await self.routing_table.get_model(model) - if model_obj is None: - raise ModelNotFoundError(model) - if model_obj.model_type != ModelType.embedding: - raise ValueError(f"Model '{model}' is not an embedding model") - + model_obj = await self._get_model(model, ModelType.embedding) params = dict( model=model_obj.identifier, input=input, @@ -871,4 +854,5 @@ class InferenceRouter(Inference): model=model.identifier, object="chat.completion", ) + logger.debug(f"InferenceRouter.completion_response: {final_response}") await self.store.store_chat_completion(final_response, messages) diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py index 9bf2b1bac..c76673d2a 100644 --- a/llama_stack/core/routers/safety.py +++ b/llama_stack/core/routers/safety.py @@ -10,7 +10,7 @@ from llama_stack.apis.inference import ( Message, ) from llama_stack.apis.safety import RunShieldResponse, Safety -from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories +from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable @@ -82,20 +82,5 @@ class SafetyRouter(Safety): input=input, model=model, ) - self._validate_required_categories_exist(response) return response - - def _validate_required_categories_exist(self, response: ModerationObject) -> None: - """Validate the ProviderImpl response contains the required Open AI moderations categories.""" - required_categories = list(map(str, OpenAICategories)) - - categories = response.results[0].categories - category_applied_input_types = response.results[0].category_applied_input_types - category_scores = response.results[0].category_scores - - for i in [categories, category_applied_input_types, category_scores]: - if not set(required_categories).issubset(set(i.keys())): - raise ValueError( - f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}" - ) diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index c76619271..34c431e00 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -63,6 +63,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def get_provider_impl(self, model_id: str) -> Any: model = await lookup_model(self, model_id) + if model.provider_id not in self.impls_by_provider_id: + raise ValueError(f"Provider {model.provider_id} not found in the routing table") return self.impls_by_provider_id[model.provider_id] async def register_model( diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py index c81a27a3b..e8dc46997 100644 --- a/llama_stack/core/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -8,7 +8,7 @@ from typing import Any from pydantic import TypeAdapter -from llama_stack.apis.common.errors import ModelNotFoundError, VectorStoreNotFoundError +from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError from llama_stack.apis.models import ModelType from llama_stack.apis.resource import ResourceType from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs @@ -66,7 +66,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): if model is None: raise ModelNotFoundError(embedding_model) if model.model_type != ModelType.embedding: - raise ValueError(f"Model {embedding_model} is not an embedding model") + raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding) if "embedding_dimension" not in model.metadata: raise ValueError(f"Model {embedding_model} does not have an embedding dimension") vector_db_data = { diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index fe5cc68d7..cbef8ef88 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -21,16 +21,18 @@ from importlib.metadata import version as parse_version from pathlib import Path from typing import Annotated, Any, get_origin +import httpx import rich.pretty import yaml from aiohttp import hdrs -from fastapi import Body, FastAPI, HTTPException, Request +from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError +from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.cli.utils import add_config_distro_args, get_config_from_args from llama_stack.core.access_control.access_control import AccessDeniedError @@ -115,7 +117,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro if isinstance(exc, RequestValidationError): return HTTPException( - status_code=400, + status_code=httpx.codes.BAD_REQUEST, detail={ "errors": [ { @@ -127,21 +129,25 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro ] }, ) + elif isinstance(exc, ConflictError): + return HTTPException(status_code=409, detail=str(exc)) + elif isinstance(exc, ResourceNotFoundError): + return HTTPException(status_code=404, detail=str(exc)) elif isinstance(exc, ValueError): - return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}") + return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}") elif isinstance(exc, BadRequestError): - return HTTPException(status_code=400, detail=str(exc)) + return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc)) elif isinstance(exc, PermissionError | AccessDeniedError): - return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}") + return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}") elif isinstance(exc, asyncio.TimeoutError | TimeoutError): - return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}") + return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}") elif isinstance(exc, NotImplementedError): - return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}") + return HTTPException(status_code=httpx.codes.NOT_IMPLEMENTED, detail=f"Not implemented: {str(exc)}") elif isinstance(exc, AuthenticationRequiredError): - return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}") + return HTTPException(status_code=httpx.codes.UNAUTHORIZED, detail=f"Authentication required: {str(exc)}") else: return HTTPException( - status_code=500, + status_code=httpx.codes.INTERNAL_SERVER_ERROR, detail="Internal server error: An unexpected error occurred.", ) @@ -180,7 +186,6 @@ async def sse_generator(event_gen_coroutine): event_gen = await event_gen_coroutine async for item in event_gen: yield create_sse_event(item) - await asyncio.sleep(0.01) except asyncio.CancelledError: logger.info("Generator cancelled") if event_gen: @@ -236,6 +241,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: result = await maybe_await(value) if isinstance(result, PaginatedResponse) and result.url is None: result.url = route + + if method.upper() == "DELETE" and result is None: + return Response(status_code=httpx.codes.NO_CONTENT) + return result except Exception as e: if logger.isEnabledFor(logging.DEBUG): @@ -352,7 +361,7 @@ class ClientVersionMiddleware: await send( { "type": "http.response.start", - "status": 426, + "status": httpx.codes.UPGRADE_REQUIRED, "headers": [[b"content-type", b"application/json"]], } ) diff --git a/llama_stack/distributions/ci-tests/build.yaml b/llama_stack/distributions/ci-tests/build.yaml index e6e699b62..676ed18d2 100644 --- a/llama_stack/distributions/ci-tests/build.yaml +++ b/llama_stack/distributions/ci-tests/build.yaml @@ -48,6 +48,8 @@ distribution_spec: - provider_type: remote::tavily-search - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol + batches: + - provider_type: inline::reference image_type: venv additional_pip_packages: - aiosqlite diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml index 05e1b4576..dd4e04e50 100644 --- a/llama_stack/distributions/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -2,6 +2,7 @@ version: 2 image_name: ci-tests apis: - agents +- batches - datasetio - eval - files @@ -204,6 +205,13 @@ providers: provider_type: inline::rag-runtime - provider_id: model-context-protocol provider_type: remote::model-context-protocol + batches: + - provider_id: reference + provider_type: inline::reference + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/batches.db metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/registry.db diff --git a/llama_stack/distributions/dell/dell.py b/llama_stack/distributions/dell/dell.py index b561ea00e..e3bf0ee03 100644 --- a/llama_stack/distributions/dell/dell.py +++ b/llama_stack/distributions/dell/dell.py @@ -16,6 +16,7 @@ from llama_stack.distributions.template import DistributionTemplate, RunConfigSe from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) +from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig def get_distribution_template() -> DistributionTemplate: @@ -71,9 +72,10 @@ def get_distribution_template() -> DistributionTemplate: chromadb_provider = Provider( provider_id="chromadb", provider_type="remote::chromadb", - config={ - "url": "${env.CHROMA_URL}", - }, + config=ChromaVectorIOConfig.sample_run_config( + f"~/.llama/distributions/{name}/", + url="${env.CHROMADB_URL:=}", + ), ) inference_model = ModelInput( diff --git a/llama_stack/distributions/dell/run-with-safety.yaml b/llama_stack/distributions/dell/run-with-safety.yaml index ecc6729eb..d89c92aa1 100644 --- a/llama_stack/distributions/dell/run-with-safety.yaml +++ b/llama_stack/distributions/dell/run-with-safety.yaml @@ -26,7 +26,10 @@ providers: - provider_id: chromadb provider_type: remote::chromadb config: - url: ${env.CHROMA_URL} + url: ${env.CHROMADB_URL:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/distributions/dell/run.yaml b/llama_stack/distributions/dell/run.yaml index fc2553526..7397410ba 100644 --- a/llama_stack/distributions/dell/run.yaml +++ b/llama_stack/distributions/dell/run.yaml @@ -22,7 +22,10 @@ providers: - provider_id: chromadb provider_type: remote::chromadb config: - url: ${env.CHROMA_URL} + url: ${env.CHROMADB_URL:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/distributions/starter/build.yaml b/llama_stack/distributions/starter/build.yaml index 1a4f81d49..549bb4529 100644 --- a/llama_stack/distributions/starter/build.yaml +++ b/llama_stack/distributions/starter/build.yaml @@ -48,6 +48,8 @@ distribution_spec: - provider_type: remote::tavily-search - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol + batches: + - provider_type: inline::reference image_type: venv additional_pip_packages: - aiosqlite diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml index 46bd12956..d64c275cb 100644 --- a/llama_stack/distributions/starter/run.yaml +++ b/llama_stack/distributions/starter/run.yaml @@ -2,6 +2,7 @@ version: 2 image_name: starter apis: - agents +- batches - datasetio - eval - files @@ -204,6 +205,13 @@ providers: provider_type: inline::rag-runtime - provider_id: model-context-protocol provider_type: remote::model-context-protocol + batches: + - provider_id: reference + provider_type: inline::reference + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/batches.db metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/registry.db diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py index 0270b68ad..498a12080 100644 --- a/llama_stack/distributions/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -139,6 +139,9 @@ def get_distribution_template() -> DistributionTemplate: BuildProvider(provider_type="inline::rag-runtime"), BuildProvider(provider_type="remote::model-context-protocol"), ], + "batches": [ + BuildProvider(provider_type="inline::reference"), + ], } files_provider = Provider( provider_id="meta-reference-files", diff --git a/llama_stack/log.py b/llama_stack/log.py index 0a2d63ef6..7507aface 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -32,6 +32,7 @@ CATEGORIES = [ "tools", "client", "telemetry", + "openai_responses", ] # Initialize category levels with default level diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py index 0a973cf0c..1f88a1699 100644 --- a/llama_stack/models/llama/llama3/chat_format.py +++ b/llama_stack/models/llama/llama3/chat_format.py @@ -236,6 +236,7 @@ class ChatFormat: arguments_json=json.dumps(tool_arguments), ) ) + content = "" return RawMessage( role="assistant", diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 15695ec48..30196c429 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -48,8 +48,8 @@ from llama_stack.providers.utils.responses.responses_store import ResponsesStore from .agent_instance import ChatAgent from .config import MetaReferenceAgentsImplConfig -from .openai_responses import OpenAIResponsesImpl from .persistence import AgentInfo +from .responses.openai_responses import OpenAIResponsesImpl logger = logging.getLogger() @@ -327,10 +327,21 @@ class MetaReferenceAgentsImpl(Agents): temperature: float | None = None, text: OpenAIResponseText | None = None, tools: list[OpenAIResponseInputTool] | None = None, + include: list[str] | None = None, max_infer_iters: int | None = 10, ) -> OpenAIResponseObject: return await self.openai_responses_impl.create_openai_response( - input, model, instructions, previous_response_id, store, stream, temperature, text, tools, max_infer_iters + input, + model, + instructions, + previous_response_id, + store, + stream, + temperature, + text, + tools, + include, + max_infer_iters, ) async def list_openai_responses( diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py deleted file mode 100644 index 7eb2b3897..000000000 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ /dev/null @@ -1,880 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import asyncio -import json -import time -import uuid -from collections.abc import AsyncIterator -from typing import Any - -from openai.types.chat import ChatCompletionToolParam -from pydantic import BaseModel - -from llama_stack.apis.agents import Order -from llama_stack.apis.agents.openai_responses import ( - AllowedToolsFilter, - ListOpenAIResponseInputItem, - ListOpenAIResponseObject, - OpenAIDeleteResponseObject, - OpenAIResponseInput, - OpenAIResponseInputFunctionToolCallOutput, - OpenAIResponseInputMessageContent, - OpenAIResponseInputMessageContentImage, - OpenAIResponseInputMessageContentText, - OpenAIResponseInputTool, - OpenAIResponseInputToolFileSearch, - OpenAIResponseInputToolMCP, - OpenAIResponseMessage, - OpenAIResponseObject, - OpenAIResponseObjectStream, - OpenAIResponseObjectStreamResponseCompleted, - OpenAIResponseObjectStreamResponseCreated, - OpenAIResponseObjectStreamResponseOutputTextDelta, - OpenAIResponseOutput, - OpenAIResponseOutputMessageContent, - OpenAIResponseOutputMessageContentOutputText, - OpenAIResponseOutputMessageFileSearchToolCall, - OpenAIResponseOutputMessageFunctionToolCall, - OpenAIResponseOutputMessageMCPListTools, - OpenAIResponseOutputMessageWebSearchToolCall, - OpenAIResponseText, - OpenAIResponseTextFormat, - WebSearchToolTypes, -) -from llama_stack.apis.common.content_types import TextContentItem -from llama_stack.apis.inference import ( - Inference, - OpenAIAssistantMessageParam, - OpenAIChatCompletion, - OpenAIChatCompletionContentPartImageParam, - OpenAIChatCompletionContentPartParam, - OpenAIChatCompletionContentPartTextParam, - OpenAIChatCompletionToolCall, - OpenAIChatCompletionToolCallFunction, - OpenAIChoice, - OpenAIDeveloperMessageParam, - OpenAIImageURL, - OpenAIJSONSchema, - OpenAIMessageParam, - OpenAIResponseFormatJSONObject, - OpenAIResponseFormatJSONSchema, - OpenAIResponseFormatParam, - OpenAIResponseFormatText, - OpenAISystemMessageParam, - OpenAIToolMessageParam, - OpenAIUserMessageParam, -) -from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime -from llama_stack.apis.vector_io import VectorIO -from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition -from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool -from llama_stack.providers.utils.responses.responses_store import ResponsesStore - -logger = get_logger(name=__name__, category="openai_responses") - -OPENAI_RESPONSES_PREFIX = "openai_responses:" - - -async def _convert_response_content_to_chat_content( - content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent], -) -> str | list[OpenAIChatCompletionContentPartParam]: - """ - Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts. - - The content schemas of each API look similar, but are not exactly the same. - """ - if isinstance(content, str): - return content - - converted_parts = [] - for content_part in content: - if isinstance(content_part, OpenAIResponseInputMessageContentText): - converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) - elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText): - converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) - elif isinstance(content_part, OpenAIResponseInputMessageContentImage): - if content_part.image_url: - image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail) - converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url)) - elif isinstance(content_part, str): - converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part)) - else: - raise ValueError( - f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context" - ) - return converted_parts - - -async def _convert_response_input_to_chat_messages( - input: str | list[OpenAIResponseInput], -) -> list[OpenAIMessageParam]: - """ - Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages. - """ - messages: list[OpenAIMessageParam] = [] - if isinstance(input, list): - for input_item in input: - if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): - messages.append( - OpenAIToolMessageParam( - content=input_item.output, - tool_call_id=input_item.call_id, - ) - ) - elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall): - tool_call = OpenAIChatCompletionToolCall( - index=0, - id=input_item.call_id, - function=OpenAIChatCompletionToolCallFunction( - name=input_item.name, - arguments=input_item.arguments, - ), - ) - messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) - else: - content = await _convert_response_content_to_chat_content(input_item.content) - message_type = await _get_message_type_by_role(input_item.role) - if message_type is None: - raise ValueError( - f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" - ) - messages.append(message_type(content=content)) - else: - messages.append(OpenAIUserMessageParam(content=input)) - return messages - - -async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage: - """ - Convert an OpenAI Chat Completion choice into an OpenAI Response output message. - """ - output_content = "" - if isinstance(choice.message.content, str): - output_content = choice.message.content - elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam): - output_content = choice.message.content.text - else: - raise ValueError( - f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}" - ) - - return OpenAIResponseMessage( - id=f"msg_{uuid.uuid4()}", - content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)], - status="completed", - role="assistant", - ) - - -async def _convert_response_text_to_chat_response_format(text: OpenAIResponseText) -> OpenAIResponseFormatParam: - """ - Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format. - """ - if not text.format or text.format["type"] == "text": - return OpenAIResponseFormatText(type="text") - if text.format["type"] == "json_object": - return OpenAIResponseFormatJSONObject() - if text.format["type"] == "json_schema": - return OpenAIResponseFormatJSONSchema( - json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"]) - ) - raise ValueError(f"Unsupported text format: {text.format}") - - -async def _get_message_type_by_role(role: str): - role_to_type = { - "user": OpenAIUserMessageParam, - "system": OpenAISystemMessageParam, - "assistant": OpenAIAssistantMessageParam, - "developer": OpenAIDeveloperMessageParam, - } - return role_to_type.get(role) - - -class OpenAIResponsePreviousResponseWithInputItems(BaseModel): - input_items: ListOpenAIResponseInputItem - response: OpenAIResponseObject - - -class ChatCompletionContext(BaseModel): - model: str - messages: list[OpenAIMessageParam] - response_tools: list[OpenAIResponseInputTool] | None = None - chat_tools: list[ChatCompletionToolParam] | None = None - mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] - temperature: float | None - response_format: OpenAIResponseFormatParam - - -class OpenAIResponsesImpl: - def __init__( - self, - inference_api: Inference, - tool_groups_api: ToolGroups, - tool_runtime_api: ToolRuntime, - responses_store: ResponsesStore, - vector_io_api: VectorIO, # VectorIO - ): - self.inference_api = inference_api - self.tool_groups_api = tool_groups_api - self.tool_runtime_api = tool_runtime_api - self.responses_store = responses_store - self.vector_io_api = vector_io_api - - async def _prepend_previous_response( - self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None - ): - if previous_response_id: - previous_response_with_input = await self.responses_store.get_response_object(previous_response_id) - - # previous response input items - new_input_items = previous_response_with_input.input - - # previous response output items - new_input_items.extend(previous_response_with_input.output) - - # new input items from the current request - if isinstance(input, str): - new_input_items.append(OpenAIResponseMessage(content=input, role="user")) - else: - new_input_items.extend(input) - - input = new_input_items - - return input - - async def _prepend_instructions(self, messages, instructions): - if instructions: - messages.insert(0, OpenAISystemMessageParam(content=instructions)) - - async def get_openai_response( - self, - response_id: str, - ) -> OpenAIResponseObject: - response_with_input = await self.responses_store.get_response_object(response_id) - return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"}) - - async def list_openai_responses( - self, - after: str | None = None, - limit: int | None = 50, - model: str | None = None, - order: Order | None = Order.desc, - ) -> ListOpenAIResponseObject: - return await self.responses_store.list_responses(after, limit, model, order) - - async def list_openai_response_input_items( - self, - response_id: str, - after: str | None = None, - before: str | None = None, - include: list[str] | None = None, - limit: int | None = 20, - order: Order | None = Order.desc, - ) -> ListOpenAIResponseInputItem: - """List input items for a given OpenAI response. - - :param response_id: The ID of the response to retrieve input items for. - :param after: An item ID to list items after, used for pagination. - :param before: An item ID to list items before, used for pagination. - :param include: Additional fields to include in the response. - :param limit: A limit on the number of objects to be returned. - :param order: The order to return the input items in. - :returns: An ListOpenAIResponseInputItem. - """ - return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order) - - async def _store_response( - self, - response: OpenAIResponseObject, - input: str | list[OpenAIResponseInput], - ) -> None: - new_input_id = f"msg_{uuid.uuid4()}" - if isinstance(input, str): - # synthesize a message from the input string - input_content = OpenAIResponseInputMessageContentText(text=input) - input_content_item = OpenAIResponseMessage( - role="user", - content=[input_content], - id=new_input_id, - ) - input_items_data = [input_content_item] - else: - # we already have a list of messages - input_items_data = [] - for input_item in input: - if isinstance(input_item, OpenAIResponseMessage): - # These may or may not already have an id, so dump to dict, check for id, and add if missing - input_item_dict = input_item.model_dump() - if "id" not in input_item_dict: - input_item_dict["id"] = new_input_id - input_items_data.append(OpenAIResponseMessage(**input_item_dict)) - else: - input_items_data.append(input_item) - - await self.responses_store.store_response_object( - response_object=response, - input=input_items_data, - ) - - async def create_openai_response( - self, - input: str | list[OpenAIResponseInput], - model: str, - instructions: str | None = None, - previous_response_id: str | None = None, - store: bool | None = True, - stream: bool | None = False, - temperature: float | None = None, - text: OpenAIResponseText | None = None, - tools: list[OpenAIResponseInputTool] | None = None, - max_infer_iters: int | None = 10, - ): - stream = bool(stream) - text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text - - stream_gen = self._create_streaming_response( - input=input, - model=model, - instructions=instructions, - previous_response_id=previous_response_id, - store=store, - temperature=temperature, - text=text, - tools=tools, - max_infer_iters=max_infer_iters, - ) - - if stream: - return stream_gen - else: - response = None - async for stream_chunk in stream_gen: - if stream_chunk.type == "response.completed": - if response is not None: - raise ValueError("The response stream completed multiple times! Earlier response: {response}") - response = stream_chunk.response - # don't leave the generator half complete! - - if response is None: - raise ValueError("The response stream never completed") - return response - - async def _create_streaming_response( - self, - input: str | list[OpenAIResponseInput], - model: str, - instructions: str | None = None, - previous_response_id: str | None = None, - store: bool | None = True, - temperature: float | None = None, - text: OpenAIResponseText | None = None, - tools: list[OpenAIResponseInputTool] | None = None, - max_infer_iters: int | None = 10, - ) -> AsyncIterator[OpenAIResponseObjectStream]: - output_messages: list[OpenAIResponseOutput] = [] - - # Input preprocessing - input = await self._prepend_previous_response(input, previous_response_id) - messages = await _convert_response_input_to_chat_messages(input) - await self._prepend_instructions(messages, instructions) - - # Structured outputs - response_format = await _convert_response_text_to_chat_response_format(text) - - # Tool setup, TODO: refactor this slightly since this can also yield events - chat_tools, mcp_tool_to_server, mcp_list_message = ( - await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None) - ) - if mcp_list_message: - output_messages.append(mcp_list_message) - - ctx = ChatCompletionContext( - model=model, - messages=messages, - response_tools=tools, - chat_tools=chat_tools, - mcp_tool_to_server=mcp_tool_to_server, - temperature=temperature, - response_format=response_format, - ) - - # Create initial response and emit response.created immediately - response_id = f"resp-{uuid.uuid4()}" - created_at = int(time.time()) - - initial_response = OpenAIResponseObject( - created_at=created_at, - id=response_id, - model=model, - object="response", - status="in_progress", - output=output_messages.copy(), - text=text, - ) - - yield OpenAIResponseObjectStreamResponseCreated(response=initial_response) - - n_iter = 0 - messages = ctx.messages.copy() - - while True: - completion_result = await self.inference_api.openai_chat_completion( - model=ctx.model, - messages=messages, - tools=ctx.chat_tools, - stream=True, - temperature=ctx.temperature, - response_format=ctx.response_format, - ) - - # Process streaming chunks and build complete response - chat_response_id = "" - chat_response_content = [] - chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {} - chunk_created = 0 - chunk_model = "" - chunk_finish_reason = "" - sequence_number = 0 - - # Create a placeholder message item for delta events - message_item_id = f"msg_{uuid.uuid4()}" - - async for chunk in completion_result: - chat_response_id = chunk.id - chunk_created = chunk.created - chunk_model = chunk.model - for chunk_choice in chunk.choices: - # Emit incremental text content as delta events - if chunk_choice.delta.content: - sequence_number += 1 - yield OpenAIResponseObjectStreamResponseOutputTextDelta( - content_index=0, - delta=chunk_choice.delta.content, - item_id=message_item_id, - output_index=0, - sequence_number=sequence_number, - ) - - # Collect content for final response - chat_response_content.append(chunk_choice.delta.content or "") - if chunk_choice.finish_reason: - chunk_finish_reason = chunk_choice.finish_reason - - # Aggregate tool call arguments across chunks - if chunk_choice.delta.tool_calls: - for tool_call in chunk_choice.delta.tool_calls: - response_tool_call = chat_response_tool_calls.get(tool_call.index, None) - if response_tool_call: - # Don't attempt to concatenate arguments if we don't have any new argumentsAdd commentMore actions - if tool_call.function.arguments: - # Guard against an initial None argument before we concatenate - response_tool_call.function.arguments = ( - response_tool_call.function.arguments or "" - ) + tool_call.function.arguments - else: - tool_call_dict: dict[str, Any] = tool_call.model_dump() - tool_call_dict.pop("type", None) - response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict) - chat_response_tool_calls[tool_call.index] = response_tool_call - - # Convert collected chunks to complete response - if chat_response_tool_calls: - tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())] - else: - tool_calls = None - assistant_message = OpenAIAssistantMessageParam( - content="".join(chat_response_content), - tool_calls=tool_calls, - ) - current_response = OpenAIChatCompletion( - id=chat_response_id, - choices=[ - OpenAIChoice( - message=assistant_message, - finish_reason=chunk_finish_reason, - index=0, - ) - ], - created=chunk_created, - model=chunk_model, - ) - - function_tool_calls = [] - non_function_tool_calls = [] - - next_turn_messages = messages.copy() - for choice in current_response.choices: - next_turn_messages.append(choice.message) - - if choice.message.tool_calls and tools: - for tool_call in choice.message.tool_calls: - if _is_function_tool_call(tool_call, tools): - function_tool_calls.append(tool_call) - else: - non_function_tool_calls.append(tool_call) - else: - output_messages.append(await _convert_chat_choice_to_response_message(choice)) - - # execute non-function tool calls - for tool_call in non_function_tool_calls: - tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx) - if tool_call_log: - output_messages.append(tool_call_log) - if tool_response_message: - next_turn_messages.append(tool_response_message) - - for tool_call in function_tool_calls: - output_messages.append( - OpenAIResponseOutputMessageFunctionToolCall( - arguments=tool_call.function.arguments or "", - call_id=tool_call.id, - name=tool_call.function.name or "", - id=f"fc_{uuid.uuid4()}", - status="completed", - ) - ) - - if not function_tool_calls and not non_function_tool_calls: - break - - if function_tool_calls: - logger.info("Exiting inference loop since there is a function (client-side) tool call") - break - - n_iter += 1 - if n_iter >= max_infer_iters: - logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}") - break - - messages = next_turn_messages - - # Create final response - final_response = OpenAIResponseObject( - created_at=created_at, - id=response_id, - model=model, - object="response", - status="completed", - text=text, - output=output_messages, - ) - - # Emit response.completed - yield OpenAIResponseObjectStreamResponseCompleted(response=final_response) - - if store: - await self._store_response( - response=final_response, - input=input, - ) - - async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject: - return await self.responses_store.delete_response_object(response_id) - - async def _convert_response_tools_to_chat_tools( - self, tools: list[OpenAIResponseInputTool] - ) -> tuple[ - list[ChatCompletionToolParam], - dict[str, OpenAIResponseInputToolMCP], - OpenAIResponseOutput | None, - ]: - from llama_stack.apis.agents.openai_responses import ( - MCPListToolsTool, - ) - from llama_stack.apis.tools import Tool - - mcp_tool_to_server = {} - - def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam: - tool_def = ToolDefinition( - tool_name=tool_name, - description=tool.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in tool.parameters - }, - ) - return convert_tooldef_to_openai_tool(tool_def) - - mcp_list_message = None - chat_tools: list[ChatCompletionToolParam] = [] - for input_tool in tools: - # TODO: Handle other tool types - if input_tool.type == "function": - chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) - elif input_tool.type in WebSearchToolTypes: - tool_name = "web_search" - tool = await self.tool_groups_api.get_tool(tool_name) - if not tool: - raise ValueError(f"Tool {tool_name} not found") - chat_tools.append(make_openai_tool(tool_name, tool)) - elif input_tool.type == "file_search": - tool_name = "knowledge_search" - tool = await self.tool_groups_api.get_tool(tool_name) - if not tool: - raise ValueError(f"Tool {tool_name} not found") - chat_tools.append(make_openai_tool(tool_name, tool)) - elif input_tool.type == "mcp": - from llama_stack.providers.utils.tools.mcp import list_mcp_tools - - always_allowed = None - never_allowed = None - if input_tool.allowed_tools: - if isinstance(input_tool.allowed_tools, list): - always_allowed = input_tool.allowed_tools - elif isinstance(input_tool.allowed_tools, AllowedToolsFilter): - always_allowed = input_tool.allowed_tools.always - never_allowed = input_tool.allowed_tools.never - - tool_defs = await list_mcp_tools( - endpoint=input_tool.server_url, - headers=input_tool.headers or {}, - ) - - mcp_list_message = OpenAIResponseOutputMessageMCPListTools( - id=f"mcp_list_{uuid.uuid4()}", - status="completed", - server_label=input_tool.server_label, - tools=[], - ) - for t in tool_defs.data: - if never_allowed and t.name in never_allowed: - continue - if not always_allowed or t.name in always_allowed: - chat_tools.append(make_openai_tool(t.name, t)) - if t.name in mcp_tool_to_server: - raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}") - mcp_tool_to_server[t.name] = input_tool - mcp_list_message.tools.append( - MCPListToolsTool( - name=t.name, - description=t.description, - input_schema={ - "type": "object", - "properties": { - p.name: { - "type": p.parameter_type, - "description": p.description, - } - for p in t.parameters - }, - "required": [p.name for p in t.parameters if p.required], - }, - ) - ) - else: - raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}") - return chat_tools, mcp_tool_to_server, mcp_list_message - - async def _execute_knowledge_search_via_vector_store( - self, - query: str, - response_file_search_tool: OpenAIResponseInputToolFileSearch, - ) -> ToolInvocationResult: - """Execute knowledge search using vector_stores.search API with filters support.""" - search_results = [] - - # Create search tasks for all vector stores - async def search_single_store(vector_store_id): - try: - search_response = await self.vector_io_api.openai_search_vector_store( - vector_store_id=vector_store_id, - query=query, - filters=response_file_search_tool.filters, - max_num_results=response_file_search_tool.max_num_results, - ranking_options=response_file_search_tool.ranking_options, - rewrite_query=False, - ) - return search_response.data - except Exception as e: - logger.warning(f"Failed to search vector store {vector_store_id}: {e}") - return [] - - # Run all searches in parallel using gather - search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids] - all_results = await asyncio.gather(*search_tasks) - - # Flatten results - for results in all_results: - search_results.extend(results) - - # Convert search results to tool result format matching memory.py - # Format the results as interleaved content similar to memory.py - content_items = [] - content_items.append( - TextContentItem( - text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n" - ) - ) - - for i, result_item in enumerate(search_results): - chunk_text = result_item.content[0].text if result_item.content else "" - metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}" - if result_item.attributes: - metadata_text += f", attributes: {result_item.attributes}" - text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n" - content_items.append(TextContentItem(text=text_content)) - - content_items.append(TextContentItem(text="END of knowledge_search tool results.\n")) - content_items.append( - TextContentItem( - text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n', - ) - ) - - return ToolInvocationResult( - content=content_items, - metadata={ - "document_ids": [r.file_id for r in search_results], - "chunks": [r.content[0].text if r.content else "" for r in search_results], - "scores": [r.score for r in search_results], - }, - ) - - async def _execute_tool_call( - self, - tool_call: OpenAIChatCompletionToolCall, - ctx: ChatCompletionContext, - ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]: - from llama_stack.providers.utils.inference.prompt_adapter import ( - interleaved_content_as_str, - ) - - tool_call_id = tool_call.id - function = tool_call.function - tool_kwargs = json.loads(function.arguments) if function.arguments else {} - - if not function or not tool_call_id or not function.name: - return None, None - - error_exc = None - result = None - try: - if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server: - from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool - - mcp_tool = ctx.mcp_tool_to_server[function.name] - result = await invoke_mcp_tool( - endpoint=mcp_tool.server_url, - headers=mcp_tool.headers or {}, - tool_name=function.name, - kwargs=tool_kwargs, - ) - elif function.name == "knowledge_search": - response_file_search_tool = next( - (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None - ) - if response_file_search_tool: - # Use vector_stores.search API instead of knowledge_search tool - # to support filters and ranking_options - query = tool_kwargs.get("query", "") - result = await self._execute_knowledge_search_via_vector_store( - query=query, - response_file_search_tool=response_file_search_tool, - ) - else: - result = await self.tool_runtime_api.invoke_tool( - tool_name=function.name, - kwargs=tool_kwargs, - ) - except Exception as e: - error_exc = e - - if function.name in ctx.mcp_tool_to_server: - from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall - - message = OpenAIResponseOutputMessageMCPCall( - id=tool_call_id, - arguments=function.arguments, - name=function.name, - server_label=ctx.mcp_tool_to_server[function.name].server_label, - ) - if error_exc: - message.error = str(error_exc) - elif (result.error_code and result.error_code > 0) or result.error_message: - message.error = f"Error (code {result.error_code}): {result.error_message}" - elif result.content: - message.output = interleaved_content_as_str(result.content) - else: - if function.name == "web_search": - message = OpenAIResponseOutputMessageWebSearchToolCall( - id=tool_call_id, - status="completed", - ) - if error_exc or (result.error_code and result.error_code > 0) or result.error_message: - message.status = "failed" - elif function.name == "knowledge_search": - message = OpenAIResponseOutputMessageFileSearchToolCall( - id=tool_call_id, - queries=[tool_kwargs.get("query", "")], - status="completed", - ) - if "document_ids" in result.metadata: - message.results = [] - for i, doc_id in enumerate(result.metadata["document_ids"]): - text = result.metadata["chunks"][i] if "chunks" in result.metadata else None - score = result.metadata["scores"][i] if "scores" in result.metadata else None - message.results.append( - { - "file_id": doc_id, - "filename": doc_id, - "text": text, - "score": score, - } - ) - if error_exc or (result.error_code and result.error_code > 0) or result.error_message: - message.status = "failed" - else: - raise ValueError(f"Unknown tool {function.name} called") - - input_message = None - if result and result.content: - if isinstance(result.content, str): - content = result.content - elif isinstance(result.content, list): - from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem - - content = [] - for item in result.content: - if isinstance(item, TextContentItem): - part = OpenAIChatCompletionContentPartTextParam(text=item.text) - elif isinstance(item, ImageContentItem): - if item.image.data: - url = f"data:image;base64,{item.image.data}" - else: - url = item.image.url - part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) - else: - raise ValueError(f"Unknown result content type: {type(item)}") - content.append(part) - else: - raise ValueError(f"Unknown result content type: {type(result.content)}") - input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) - else: - text = str(error_exc) - input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id) - - return message, input_message - - -def _is_function_tool_call( - tool_call: OpenAIChatCompletionToolCall, - tools: list[OpenAIResponseInputTool], -) -> bool: - if not tool_call.function: - return False - for t in tools: - if t.type == "function" and t.name == tool_call.function.name: - return True - return False diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 7a8d99b78..0b234d96c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -191,7 +191,11 @@ class AgentPersistence: sessions = [] for value in values: try: - session_info = Session(**json.loads(value)) + data = json.loads(value) + if "turn_id" in data: + continue + + session_info = Session(**data) sessions.append(session_info) except Exception as e: log.error(f"Error parsing session info: {e}") diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/__init__.py b/llama_stack/providers/inline/agents/meta_reference/responses/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py new file mode 100644 index 000000000..e528a4005 --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import time +import uuid +from collections.abc import AsyncIterator + +from pydantic import BaseModel + +from llama_stack.apis.agents import Order +from llama_stack.apis.agents.openai_responses import ( + ListOpenAIResponseInputItem, + ListOpenAIResponseObject, + OpenAIDeleteResponseObject, + OpenAIResponseInput, + OpenAIResponseInputMessageContentText, + OpenAIResponseInputTool, + OpenAIResponseMessage, + OpenAIResponseObject, + OpenAIResponseObjectStream, + OpenAIResponseText, + OpenAIResponseTextFormat, +) +from llama_stack.apis.inference import ( + Inference, + OpenAISystemMessageParam, +) +from llama_stack.apis.tools import ToolGroups, ToolRuntime +from llama_stack.apis.vector_io import VectorIO +from llama_stack.log import get_logger +from llama_stack.providers.utils.responses.responses_store import ResponsesStore + +from .streaming import StreamingResponseOrchestrator +from .tool_executor import ToolExecutor +from .types import ChatCompletionContext +from .utils import ( + convert_response_input_to_chat_messages, + convert_response_text_to_chat_response_format, +) + +logger = get_logger(name=__name__, category="responses") + + +class OpenAIResponsePreviousResponseWithInputItems(BaseModel): + input_items: ListOpenAIResponseInputItem + response: OpenAIResponseObject + + +class OpenAIResponsesImpl: + def __init__( + self, + inference_api: Inference, + tool_groups_api: ToolGroups, + tool_runtime_api: ToolRuntime, + responses_store: ResponsesStore, + vector_io_api: VectorIO, # VectorIO + ): + self.inference_api = inference_api + self.tool_groups_api = tool_groups_api + self.tool_runtime_api = tool_runtime_api + self.responses_store = responses_store + self.vector_io_api = vector_io_api + self.tool_executor = ToolExecutor( + tool_groups_api=tool_groups_api, + tool_runtime_api=tool_runtime_api, + vector_io_api=vector_io_api, + ) + + async def _prepend_previous_response( + self, + input: str | list[OpenAIResponseInput], + previous_response_id: str | None = None, + ): + if previous_response_id: + previous_response_with_input = await self.responses_store.get_response_object(previous_response_id) + + # previous response input items + new_input_items = previous_response_with_input.input + + # previous response output items + new_input_items.extend(previous_response_with_input.output) + + # new input items from the current request + if isinstance(input, str): + new_input_items.append(OpenAIResponseMessage(content=input, role="user")) + else: + new_input_items.extend(input) + + input = new_input_items + + return input + + async def _prepend_instructions(self, messages, instructions): + if instructions: + messages.insert(0, OpenAISystemMessageParam(content=instructions)) + + async def get_openai_response( + self, + response_id: str, + ) -> OpenAIResponseObject: + response_with_input = await self.responses_store.get_response_object(response_id) + return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"}) + + async def list_openai_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + return await self.responses_store.list_responses(after, limit, model, order) + + async def list_openai_response_input_items( + self, + response_id: str, + after: str | None = None, + before: str | None = None, + include: list[str] | None = None, + limit: int | None = 20, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseInputItem: + """List input items for a given OpenAI response. + + :param response_id: The ID of the response to retrieve input items for. + :param after: An item ID to list items after, used for pagination. + :param before: An item ID to list items before, used for pagination. + :param include: Additional fields to include in the response. + :param limit: A limit on the number of objects to be returned. + :param order: The order to return the input items in. + :returns: An ListOpenAIResponseInputItem. + """ + return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order) + + async def _store_response( + self, + response: OpenAIResponseObject, + input: str | list[OpenAIResponseInput], + ) -> None: + new_input_id = f"msg_{uuid.uuid4()}" + if isinstance(input, str): + # synthesize a message from the input string + input_content = OpenAIResponseInputMessageContentText(text=input) + input_content_item = OpenAIResponseMessage( + role="user", + content=[input_content], + id=new_input_id, + ) + input_items_data = [input_content_item] + else: + # we already have a list of messages + input_items_data = [] + for input_item in input: + if isinstance(input_item, OpenAIResponseMessage): + # These may or may not already have an id, so dump to dict, check for id, and add if missing + input_item_dict = input_item.model_dump() + if "id" not in input_item_dict: + input_item_dict["id"] = new_input_id + input_items_data.append(OpenAIResponseMessage(**input_item_dict)) + else: + input_items_data.append(input_item) + + await self.responses_store.store_response_object( + response_object=response, + input=input_items_data, + ) + + async def create_openai_response( + self, + input: str | list[OpenAIResponseInput], + model: str, + instructions: str | None = None, + previous_response_id: str | None = None, + store: bool | None = True, + stream: bool | None = False, + temperature: float | None = None, + text: OpenAIResponseText | None = None, + tools: list[OpenAIResponseInputTool] | None = None, + include: list[str] | None = None, + max_infer_iters: int | None = 10, + ): + stream = bool(stream) + text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text + + stream_gen = self._create_streaming_response( + input=input, + model=model, + instructions=instructions, + previous_response_id=previous_response_id, + store=store, + temperature=temperature, + text=text, + tools=tools, + max_infer_iters=max_infer_iters, + ) + + if stream: + return stream_gen + else: + response = None + async for stream_chunk in stream_gen: + if stream_chunk.type == "response.completed": + if response is not None: + raise ValueError("The response stream completed multiple times! Earlier response: {response}") + response = stream_chunk.response + # don't leave the generator half complete! + + if response is None: + raise ValueError("The response stream never completed") + return response + + async def _create_streaming_response( + self, + input: str | list[OpenAIResponseInput], + model: str, + instructions: str | None = None, + previous_response_id: str | None = None, + store: bool | None = True, + temperature: float | None = None, + text: OpenAIResponseText | None = None, + tools: list[OpenAIResponseInputTool] | None = None, + max_infer_iters: int | None = 10, + ) -> AsyncIterator[OpenAIResponseObjectStream]: + # Input preprocessing + input = await self._prepend_previous_response(input, previous_response_id) + messages = await convert_response_input_to_chat_messages(input) + await self._prepend_instructions(messages, instructions) + + # Structured outputs + response_format = await convert_response_text_to_chat_response_format(text) + + ctx = ChatCompletionContext( + model=model, + messages=messages, + response_tools=tools, + temperature=temperature, + response_format=response_format, + ) + + # Create orchestrator and delegate streaming logic + response_id = f"resp-{uuid.uuid4()}" + created_at = int(time.time()) + + orchestrator = StreamingResponseOrchestrator( + inference_api=self.inference_api, + ctx=ctx, + response_id=response_id, + created_at=created_at, + text=text, + max_infer_iters=max_infer_iters, + tool_executor=self.tool_executor, + ) + + # Stream the response + final_response = None + async for stream_chunk in orchestrator.create_response(): + if stream_chunk.type == "response.completed": + final_response = stream_chunk.response + yield stream_chunk + + # Store the response if requested + if store and final_response: + await self._store_response( + response=final_response, + input=input, + ) + + async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject: + return await self.responses_store.delete_response_object(response_id) diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py new file mode 100644 index 000000000..0879e978a --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -0,0 +1,634 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import uuid +from collections.abc import AsyncIterator +from typing import Any + +from llama_stack.apis.agents.openai_responses import ( + AllowedToolsFilter, + MCPListToolsTool, + OpenAIResponseContentPartOutputText, + OpenAIResponseInputTool, + OpenAIResponseInputToolMCP, + OpenAIResponseObject, + OpenAIResponseObjectStream, + OpenAIResponseObjectStreamResponseCompleted, + OpenAIResponseObjectStreamResponseContentPartAdded, + OpenAIResponseObjectStreamResponseContentPartDone, + OpenAIResponseObjectStreamResponseCreated, + OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta, + OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone, + OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta, + OpenAIResponseObjectStreamResponseMcpCallArgumentsDone, + OpenAIResponseObjectStreamResponseMcpListToolsCompleted, + OpenAIResponseObjectStreamResponseMcpListToolsInProgress, + OpenAIResponseObjectStreamResponseOutputItemAdded, + OpenAIResponseObjectStreamResponseOutputItemDone, + OpenAIResponseObjectStreamResponseOutputTextDelta, + OpenAIResponseOutput, + OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseOutputMessageMCPListTools, + OpenAIResponseText, + WebSearchToolTypes, +) +from llama_stack.apis.inference import ( + Inference, + OpenAIAssistantMessageParam, + OpenAIChatCompletion, + OpenAIChatCompletionToolCall, + OpenAIChoice, +) +from llama_stack.log import get_logger + +from .types import ChatCompletionContext, ChatCompletionResult +from .utils import convert_chat_choice_to_response_message, is_function_tool_call + +logger = get_logger(name=__name__, category="responses") + + +class StreamingResponseOrchestrator: + def __init__( + self, + inference_api: Inference, + ctx: ChatCompletionContext, + response_id: str, + created_at: int, + text: OpenAIResponseText, + max_infer_iters: int, + tool_executor, # Will be the tool execution logic from the main class + ): + self.inference_api = inference_api + self.ctx = ctx + self.response_id = response_id + self.created_at = created_at + self.text = text + self.max_infer_iters = max_infer_iters + self.tool_executor = tool_executor + self.sequence_number = 0 + # Store MCP tool mapping that gets built during tool processing + self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {} + + async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]: + # Initialize output messages + output_messages: list[OpenAIResponseOutput] = [] + # Create initial response and emit response.created immediately + initial_response = OpenAIResponseObject( + created_at=self.created_at, + id=self.response_id, + model=self.ctx.model, + object="response", + status="in_progress", + output=output_messages.copy(), + text=self.text, + ) + + yield OpenAIResponseObjectStreamResponseCreated(response=initial_response) + + # Process all tools (including MCP tools) and emit streaming events + if self.ctx.response_tools: + async for stream_event in self._process_tools(self.ctx.response_tools, output_messages): + yield stream_event + + n_iter = 0 + messages = self.ctx.messages.copy() + + while True: + completion_result = await self.inference_api.openai_chat_completion( + model=self.ctx.model, + messages=messages, + tools=self.ctx.chat_tools, + stream=True, + temperature=self.ctx.temperature, + response_format=self.ctx.response_format, + ) + + # Process streaming chunks and build complete response + completion_result_data = None + async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages): + if isinstance(stream_event_or_result, ChatCompletionResult): + completion_result_data = stream_event_or_result + else: + yield stream_event_or_result + if not completion_result_data: + raise ValueError("Streaming chunk processor failed to return completion data") + current_response = self._build_chat_completion(completion_result_data) + + function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls( + current_response, messages + ) + + # Handle choices with no tool calls + for choice in current_response.choices: + if not (choice.message.tool_calls and self.ctx.response_tools): + output_messages.append(await convert_chat_choice_to_response_message(choice)) + + # Execute tool calls and coordinate results + async for stream_event in self._coordinate_tool_execution( + function_tool_calls, + non_function_tool_calls, + completion_result_data, + output_messages, + next_turn_messages, + ): + yield stream_event + + if not function_tool_calls and not non_function_tool_calls: + break + + if function_tool_calls: + logger.info("Exiting inference loop since there is a function (client-side) tool call") + break + + n_iter += 1 + if n_iter >= self.max_infer_iters: + logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}") + break + + messages = next_turn_messages + + # Create final response + final_response = OpenAIResponseObject( + created_at=self.created_at, + id=self.response_id, + model=self.ctx.model, + object="response", + status="completed", + text=self.text, + output=output_messages, + ) + + # Emit response.completed + yield OpenAIResponseObjectStreamResponseCompleted(response=final_response) + + def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]: + """Separate tool calls into function and non-function categories.""" + function_tool_calls = [] + non_function_tool_calls = [] + next_turn_messages = messages.copy() + + for choice in current_response.choices: + next_turn_messages.append(choice.message) + + if choice.message.tool_calls and self.ctx.response_tools: + for tool_call in choice.message.tool_calls: + if is_function_tool_call(tool_call, self.ctx.response_tools): + function_tool_calls.append(tool_call) + else: + non_function_tool_calls.append(tool_call) + + return function_tool_calls, non_function_tool_calls, next_turn_messages + + async def _process_streaming_chunks( + self, completion_result, output_messages: list[OpenAIResponseOutput] + ) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]: + """Process streaming chunks and emit events, returning completion data.""" + # Initialize result tracking + chat_response_id = "" + chat_response_content = [] + chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {} + chunk_created = 0 + chunk_model = "" + chunk_finish_reason = "" + + # Create a placeholder message item for delta events + message_item_id = f"msg_{uuid.uuid4()}" + # Track tool call items for streaming events + tool_call_item_ids: dict[int, str] = {} + # Track content parts for streaming events + content_part_emitted = False + + async for chunk in completion_result: + chat_response_id = chunk.id + chunk_created = chunk.created + chunk_model = chunk.model + for chunk_choice in chunk.choices: + # Emit incremental text content as delta events + if chunk_choice.delta.content: + # Emit content_part.added event for first text chunk + if not content_part_emitted: + content_part_emitted = True + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseContentPartAdded( + response_id=self.response_id, + item_id=message_item_id, + part=OpenAIResponseContentPartOutputText( + text="", # Will be filled incrementally via text deltas + ), + sequence_number=self.sequence_number, + ) + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputTextDelta( + content_index=0, + delta=chunk_choice.delta.content, + item_id=message_item_id, + output_index=0, + sequence_number=self.sequence_number, + ) + + # Collect content for final response + chat_response_content.append(chunk_choice.delta.content or "") + if chunk_choice.finish_reason: + chunk_finish_reason = chunk_choice.finish_reason + + # Aggregate tool call arguments across chunks + if chunk_choice.delta.tool_calls: + for tool_call in chunk_choice.delta.tool_calls: + response_tool_call = chat_response_tool_calls.get(tool_call.index, None) + # Create new tool call entry if this is the first chunk for this index + is_new_tool_call = response_tool_call is None + if is_new_tool_call: + tool_call_dict: dict[str, Any] = tool_call.model_dump() + tool_call_dict.pop("type", None) + response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict) + chat_response_tool_calls[tool_call.index] = response_tool_call + + # Create item ID for this tool call for streaming events + tool_call_item_id = f"fc_{uuid.uuid4()}" + tool_call_item_ids[tool_call.index] = tool_call_item_id + + # Emit output_item.added event for the new function call + self.sequence_number += 1 + function_call_item = OpenAIResponseOutputMessageFunctionToolCall( + arguments="", # Will be filled incrementally via delta events + call_id=tool_call.id or "", + name=tool_call.function.name if tool_call.function else "", + id=tool_call_item_id, + status="in_progress", + ) + yield OpenAIResponseObjectStreamResponseOutputItemAdded( + response_id=self.response_id, + item=function_call_item, + output_index=len(output_messages), + sequence_number=self.sequence_number, + ) + + # Stream tool call arguments as they arrive (differentiate between MCP and function calls) + if tool_call.function and tool_call.function.arguments: + tool_call_item_id = tool_call_item_ids[tool_call.index] + self.sequence_number += 1 + + # Check if this is an MCP tool call + is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server + if is_mcp_tool: + # Emit MCP-specific argument delta event + yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta( + delta=tool_call.function.arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=self.sequence_number, + ) + else: + # Emit function call argument delta event + yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta( + delta=tool_call.function.arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=self.sequence_number, + ) + + # Accumulate arguments for final response (only for subsequent chunks) + if not is_new_tool_call: + response_tool_call.function.arguments = ( + response_tool_call.function.arguments or "" + ) + tool_call.function.arguments + + # Emit arguments.done events for completed tool calls (differentiate between MCP and function calls) + for tool_call_index in sorted(chat_response_tool_calls.keys()): + tool_call_item_id = tool_call_item_ids[tool_call_index] + final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or "" + tool_call_name = chat_response_tool_calls[tool_call_index].function.name + + # Check if this is an MCP tool call + is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server + self.sequence_number += 1 + done_event_cls = ( + OpenAIResponseObjectStreamResponseMcpCallArgumentsDone + if is_mcp_tool + else OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone + ) + yield done_event_cls( + arguments=final_arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=self.sequence_number, + ) + + # Emit content_part.done event if text content was streamed (before content gets cleared) + if content_part_emitted: + final_text = "".join(chat_response_content) + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseContentPartDone( + response_id=self.response_id, + item_id=message_item_id, + part=OpenAIResponseContentPartOutputText( + text=final_text, + ), + sequence_number=self.sequence_number, + ) + + # Clear content when there are tool calls (OpenAI spec behavior) + if chat_response_tool_calls: + chat_response_content = [] + + yield ChatCompletionResult( + response_id=chat_response_id, + content=chat_response_content, + tool_calls=chat_response_tool_calls, + created=chunk_created, + model=chunk_model, + finish_reason=chunk_finish_reason, + message_item_id=message_item_id, + tool_call_item_ids=tool_call_item_ids, + content_part_emitted=content_part_emitted, + ) + + def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion: + """Build OpenAIChatCompletion from ChatCompletionResult.""" + # Convert collected chunks to complete response + if result.tool_calls: + tool_calls = [result.tool_calls[i] for i in sorted(result.tool_calls.keys())] + else: + tool_calls = None + + assistant_message = OpenAIAssistantMessageParam( + content=result.content_text, + tool_calls=tool_calls, + ) + return OpenAIChatCompletion( + id=result.response_id, + choices=[ + OpenAIChoice( + message=assistant_message, + finish_reason=result.finish_reason, + index=0, + ) + ], + created=result.created, + model=result.model, + ) + + async def _coordinate_tool_execution( + self, + function_tool_calls: list, + non_function_tool_calls: list, + completion_result_data: ChatCompletionResult, + output_messages: list[OpenAIResponseOutput], + next_turn_messages: list, + ) -> AsyncIterator[OpenAIResponseObjectStream]: + """Coordinate execution of both function and non-function tool calls.""" + # Execute non-function tool calls + for tool_call in non_function_tool_calls: + # Find the item_id for this tool call + matching_item_id = None + for index, item_id in completion_result_data.tool_call_item_ids.items(): + response_tool_call = completion_result_data.tool_calls.get(index) + if response_tool_call and response_tool_call.id == tool_call.id: + matching_item_id = item_id + break + + # Use a fallback item_id if not found + if not matching_item_id: + matching_item_id = f"tc_{uuid.uuid4()}" + + # Execute tool call with streaming + tool_call_log = None + tool_response_message = None + async for result in self.tool_executor.execute_tool_call( + tool_call, + self.ctx, + self.sequence_number, + len(output_messages), + matching_item_id, + self.mcp_tool_to_server, + ): + if result.stream_event: + # Forward streaming events + self.sequence_number = result.sequence_number + yield result.stream_event + + if result.final_output_message is not None: + tool_call_log = result.final_output_message + tool_response_message = result.final_input_message + self.sequence_number = result.sequence_number + + if tool_call_log: + output_messages.append(tool_call_log) + + # Emit output_item.done event for completed non-function tool call + if matching_item_id: + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemDone( + response_id=self.response_id, + item=tool_call_log, + output_index=len(output_messages) - 1, + sequence_number=self.sequence_number, + ) + + if tool_response_message: + next_turn_messages.append(tool_response_message) + + # Execute function tool calls (client-side) + for tool_call in function_tool_calls: + # Find the item_id for this tool call from our tracking dictionary + matching_item_id = None + for index, item_id in completion_result_data.tool_call_item_ids.items(): + response_tool_call = completion_result_data.tool_calls.get(index) + if response_tool_call and response_tool_call.id == tool_call.id: + matching_item_id = item_id + break + + # Use existing item_id or create new one if not found + final_item_id = matching_item_id or f"fc_{uuid.uuid4()}" + + function_call_item = OpenAIResponseOutputMessageFunctionToolCall( + arguments=tool_call.function.arguments or "", + call_id=tool_call.id, + name=tool_call.function.name or "", + id=final_item_id, + status="completed", + ) + output_messages.append(function_call_item) + + # Emit output_item.done event for completed function call + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemDone( + response_id=self.response_id, + item=function_call_item, + output_index=len(output_messages) - 1, + sequence_number=self.sequence_number, + ) + + async def _process_tools( + self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput] + ) -> AsyncIterator[OpenAIResponseObjectStream]: + """Process all tools and emit appropriate streaming events.""" + from openai.types.chat import ChatCompletionToolParam + + from llama_stack.apis.tools import Tool + from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition + from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool + + def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam: + tool_def = ToolDefinition( + tool_name=tool_name, + description=tool.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in tool.parameters + }, + ) + return convert_tooldef_to_openai_tool(tool_def) + + # Initialize chat_tools if not already set + if self.ctx.chat_tools is None: + self.ctx.chat_tools = [] + + for input_tool in tools: + if input_tool.type == "function": + self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) + elif input_tool.type in WebSearchToolTypes: + tool_name = "web_search" + # Need to access tool_groups_api from tool_executor + tool = await self.tool_executor.tool_groups_api.get_tool(tool_name) + if not tool: + raise ValueError(f"Tool {tool_name} not found") + self.ctx.chat_tools.append(make_openai_tool(tool_name, tool)) + elif input_tool.type == "file_search": + tool_name = "knowledge_search" + tool = await self.tool_executor.tool_groups_api.get_tool(tool_name) + if not tool: + raise ValueError(f"Tool {tool_name} not found") + self.ctx.chat_tools.append(make_openai_tool(tool_name, tool)) + elif input_tool.type == "mcp": + async for stream_event in self._process_mcp_tool(input_tool, output_messages): + yield stream_event + else: + raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}") + + async def _process_mcp_tool( + self, mcp_tool: OpenAIResponseInputToolMCP, output_messages: list[OpenAIResponseOutput] + ) -> AsyncIterator[OpenAIResponseObjectStream]: + """Process an MCP tool configuration and emit appropriate streaming events.""" + from llama_stack.providers.utils.tools.mcp import list_mcp_tools + + # Emit mcp_list_tools.in_progress + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress( + sequence_number=self.sequence_number, + ) + + try: + # Parse allowed/never allowed tools + always_allowed = None + never_allowed = None + if mcp_tool.allowed_tools: + if isinstance(mcp_tool.allowed_tools, list): + always_allowed = mcp_tool.allowed_tools + elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter): + always_allowed = mcp_tool.allowed_tools.always + never_allowed = mcp_tool.allowed_tools.never + + # Call list_mcp_tools + tool_defs = await list_mcp_tools( + endpoint=mcp_tool.server_url, + headers=mcp_tool.headers or {}, + ) + + # Create the MCP list tools message + mcp_list_message = OpenAIResponseOutputMessageMCPListTools( + id=f"mcp_list_{uuid.uuid4()}", + server_label=mcp_tool.server_label, + tools=[], + ) + + # Process tools and update context + for t in tool_defs.data: + if never_allowed and t.name in never_allowed: + continue + if not always_allowed or t.name in always_allowed: + # Add to chat tools for inference + from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition + from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool + + tool_def = ToolDefinition( + tool_name=t.name, + description=t.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in t.parameters + }, + ) + openai_tool = convert_tooldef_to_openai_tool(tool_def) + if self.ctx.chat_tools is None: + self.ctx.chat_tools = [] + self.ctx.chat_tools.append(openai_tool) + + # Add to MCP tool mapping + if t.name in self.mcp_tool_to_server: + raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}") + self.mcp_tool_to_server[t.name] = mcp_tool + + # Add to MCP list message + mcp_list_message.tools.append( + MCPListToolsTool( + name=t.name, + description=t.description, + input_schema={ + "type": "object", + "properties": { + p.name: { + "type": p.parameter_type, + "description": p.description, + } + for p in t.parameters + }, + "required": [p.name for p in t.parameters if p.required], + }, + ) + ) + + # Add the MCP list message to output + output_messages.append(mcp_list_message) + + # Emit output_item.added for the MCP list tools message + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemAdded( + response_id=self.response_id, + item=mcp_list_message, + output_index=len(output_messages) - 1, + sequence_number=self.sequence_number, + ) + + # Emit mcp_list_tools.completed + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted( + sequence_number=self.sequence_number, + ) + + # Emit output_item.done for the MCP list tools message + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemDone( + response_id=self.response_id, + item=mcp_list_message, + output_index=len(output_messages) - 1, + sequence_number=self.sequence_number, + ) + + except Exception as e: + # TODO: Emit mcp_list_tools.failed event if needed + logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}") + raise diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py new file mode 100644 index 000000000..5b98b4f51 --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py @@ -0,0 +1,379 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import json +from collections.abc import AsyncIterator + +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseInputToolFileSearch, + OpenAIResponseInputToolMCP, + OpenAIResponseObjectStreamResponseMcpCallCompleted, + OpenAIResponseObjectStreamResponseMcpCallFailed, + OpenAIResponseObjectStreamResponseMcpCallInProgress, + OpenAIResponseObjectStreamResponseWebSearchCallCompleted, + OpenAIResponseObjectStreamResponseWebSearchCallInProgress, + OpenAIResponseObjectStreamResponseWebSearchCallSearching, + OpenAIResponseOutputMessageFileSearchToolCall, + OpenAIResponseOutputMessageFileSearchToolCallResults, + OpenAIResponseOutputMessageWebSearchToolCall, +) +from llama_stack.apis.common.content_types import ( + ImageContentItem, + TextContentItem, +) +from llama_stack.apis.inference import ( + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, + OpenAIChatCompletionToolCall, + OpenAIImageURL, + OpenAIToolMessageParam, +) +from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime +from llama_stack.apis.vector_io import VectorIO +from llama_stack.log import get_logger + +from .types import ChatCompletionContext, ToolExecutionResult + +logger = get_logger(name=__name__, category="responses") + + +class ToolExecutor: + def __init__( + self, + tool_groups_api: ToolGroups, + tool_runtime_api: ToolRuntime, + vector_io_api: VectorIO, + ): + self.tool_groups_api = tool_groups_api + self.tool_runtime_api = tool_runtime_api + self.vector_io_api = vector_io_api + + async def execute_tool_call( + self, + tool_call: OpenAIChatCompletionToolCall, + ctx: ChatCompletionContext, + sequence_number: int, + output_index: int, + item_id: str, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> AsyncIterator[ToolExecutionResult]: + tool_call_id = tool_call.id + function = tool_call.function + tool_kwargs = json.loads(function.arguments) if function.arguments else {} + + if not function or not tool_call_id or not function.name: + yield ToolExecutionResult(sequence_number=sequence_number) + return + + # Emit progress events for tool execution start + async for event_result in self._emit_progress_events( + function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server + ): + sequence_number = event_result.sequence_number + yield event_result + + # Execute the actual tool call + error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server) + + # Emit completion events for tool execution + has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message)) + async for event_result in self._emit_completion_events( + function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server + ): + sequence_number = event_result.sequence_number + yield event_result + + # Build result messages from tool execution + output_message, input_message = await self._build_result_messages( + function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server + ) + + # Yield the final result + yield ToolExecutionResult( + sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message + ) + + async def _execute_knowledge_search_via_vector_store( + self, + query: str, + response_file_search_tool: OpenAIResponseInputToolFileSearch, + ) -> ToolInvocationResult: + """Execute knowledge search using vector_stores.search API with filters support.""" + search_results = [] + + # Create search tasks for all vector stores + async def search_single_store(vector_store_id): + try: + search_response = await self.vector_io_api.openai_search_vector_store( + vector_store_id=vector_store_id, + query=query, + filters=response_file_search_tool.filters, + max_num_results=response_file_search_tool.max_num_results, + ranking_options=response_file_search_tool.ranking_options, + rewrite_query=False, + ) + return search_response.data + except Exception as e: + logger.warning(f"Failed to search vector store {vector_store_id}: {e}") + return [] + + # Run all searches in parallel using gather + search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids] + all_results = await asyncio.gather(*search_tasks) + + # Flatten results + for results in all_results: + search_results.extend(results) + + # Convert search results to tool result format matching memory.py + # Format the results as interleaved content similar to memory.py + content_items = [] + content_items.append( + TextContentItem( + text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n" + ) + ) + + for i, result_item in enumerate(search_results): + chunk_text = result_item.content[0].text if result_item.content else "" + metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}" + if result_item.attributes: + metadata_text += f", attributes: {result_item.attributes}" + text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n" + content_items.append(TextContentItem(text=text_content)) + + content_items.append(TextContentItem(text="END of knowledge_search tool results.\n")) + content_items.append( + TextContentItem( + text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n', + ) + ) + + return ToolInvocationResult( + content=content_items, + metadata={ + "document_ids": [r.file_id for r in search_results], + "chunks": [r.content[0].text if r.content else "" for r in search_results], + "scores": [r.score for r in search_results], + }, + ) + + async def _emit_progress_events( + self, + function_name: str, + ctx: ChatCompletionContext, + sequence_number: int, + output_index: int, + item_id: str, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> AsyncIterator[ToolExecutionResult]: + """Emit progress events for tool execution start.""" + # Emit in_progress event based on tool type (only for tools with specific streaming events) + progress_event = None + if mcp_tool_to_server and function_name in mcp_tool_to_server: + sequence_number += 1 + progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + elif function_name == "web_search": + sequence_number += 1 + progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + # Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec + + if progress_event: + yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number) + + # For web search, emit searching event + if function_name == "web_search": + sequence_number += 1 + searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number) + + async def _execute_tool( + self, + function_name: str, + tool_kwargs: dict, + ctx: ChatCompletionContext, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> tuple[Exception | None, any]: + """Execute the tool and return error exception and result.""" + error_exc = None + result = None + + try: + if mcp_tool_to_server and function_name in mcp_tool_to_server: + from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool + + mcp_tool = mcp_tool_to_server[function_name] + result = await invoke_mcp_tool( + endpoint=mcp_tool.server_url, + headers=mcp_tool.headers or {}, + tool_name=function_name, + kwargs=tool_kwargs, + ) + elif function_name == "knowledge_search": + response_file_search_tool = next( + (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), + None, + ) + if response_file_search_tool: + # Use vector_stores.search API instead of knowledge_search tool + # to support filters and ranking_options + query = tool_kwargs.get("query", "") + result = await self._execute_knowledge_search_via_vector_store( + query=query, + response_file_search_tool=response_file_search_tool, + ) + else: + result = await self.tool_runtime_api.invoke_tool( + tool_name=function_name, + kwargs=tool_kwargs, + ) + except Exception as e: + error_exc = e + + return error_exc, result + + async def _emit_completion_events( + self, + function_name: str, + ctx: ChatCompletionContext, + sequence_number: int, + output_index: int, + item_id: str, + has_error: bool, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> AsyncIterator[ToolExecutionResult]: + """Emit completion or failure events for tool execution.""" + completion_event = None + + if mcp_tool_to_server and function_name in mcp_tool_to_server: + sequence_number += 1 + if has_error: + completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed( + sequence_number=sequence_number, + ) + else: + completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( + sequence_number=sequence_number, + ) + elif function_name == "web_search": + sequence_number += 1 + completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + # Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec + + if completion_event: + yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number) + + async def _build_result_messages( + self, + function, + tool_call_id: str, + tool_kwargs: dict, + ctx: ChatCompletionContext, + error_exc: Exception | None, + result: any, + has_error: bool, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> tuple[any, any]: + """Build output and input messages from tool execution results.""" + from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, + ) + + # Build output message + if mcp_tool_to_server and function.name in mcp_tool_to_server: + from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseOutputMessageMCPCall, + ) + + message = OpenAIResponseOutputMessageMCPCall( + id=tool_call_id, + arguments=function.arguments, + name=function.name, + server_label=mcp_tool_to_server[function.name].server_label, + ) + if error_exc: + message.error = str(error_exc) + elif (result and result.error_code and result.error_code > 0) or (result and result.error_message): + message.error = f"Error (code {result.error_code}): {result.error_message}" + elif result and result.content: + message.output = interleaved_content_as_str(result.content) + else: + if function.name == "web_search": + message = OpenAIResponseOutputMessageWebSearchToolCall( + id=tool_call_id, + status="completed", + ) + if has_error: + message.status = "failed" + elif function.name == "knowledge_search": + message = OpenAIResponseOutputMessageFileSearchToolCall( + id=tool_call_id, + queries=[tool_kwargs.get("query", "")], + status="completed", + ) + if result and "document_ids" in result.metadata: + message.results = [] + for i, doc_id in enumerate(result.metadata["document_ids"]): + text = result.metadata["chunks"][i] if "chunks" in result.metadata else None + score = result.metadata["scores"][i] if "scores" in result.metadata else None + message.results.append( + OpenAIResponseOutputMessageFileSearchToolCallResults( + file_id=doc_id, + filename=doc_id, + text=text, + score=score, + attributes={}, + ) + ) + if has_error: + message.status = "failed" + else: + raise ValueError(f"Unknown tool {function.name} called") + + # Build input message + input_message = None + if result and result.content: + if isinstance(result.content, str): + content = result.content + elif isinstance(result.content, list): + content = [] + for item in result.content: + if isinstance(item, TextContentItem): + part = OpenAIChatCompletionContentPartTextParam(text=item.text) + elif isinstance(item, ImageContentItem): + if item.image.data: + url = f"data:image;base64,{item.image.data}" + else: + url = item.image.url + part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) + else: + raise ValueError(f"Unknown result content type: {type(item)}") + content.append(part) + else: + raise ValueError(f"Unknown result content type: {type(result.content)}") + input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) + else: + text = str(error_exc) if error_exc else "Tool execution failed" + input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id) + + return message, input_message diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/types.py b/llama_stack/providers/inline/agents/meta_reference/responses/types.py new file mode 100644 index 000000000..89086c262 --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/types.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from dataclasses import dataclass + +from openai.types.chat import ChatCompletionToolParam +from pydantic import BaseModel + +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseInputTool, + OpenAIResponseObjectStream, + OpenAIResponseOutput, +) +from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam + + +class ToolExecutionResult(BaseModel): + """Result of streaming tool execution.""" + + stream_event: OpenAIResponseObjectStream | None = None + sequence_number: int + final_output_message: OpenAIResponseOutput | None = None + final_input_message: OpenAIMessageParam | None = None + + +@dataclass +class ChatCompletionResult: + """Result of processing streaming chat completion chunks.""" + + response_id: str + content: list[str] + tool_calls: dict[int, OpenAIChatCompletionToolCall] + created: int + model: str + finish_reason: str + message_item_id: str # For streaming events + tool_call_item_ids: dict[int, str] # For streaming events + content_part_emitted: bool # Tracking state + + @property + def content_text(self) -> str: + """Get joined content as string.""" + return "".join(self.content) + + @property + def has_tool_calls(self) -> bool: + """Check if there are any tool calls.""" + return bool(self.tool_calls) + + +class ChatCompletionContext(BaseModel): + model: str + messages: list[OpenAIMessageParam] + response_tools: list[OpenAIResponseInputTool] | None = None + chat_tools: list[ChatCompletionToolParam] | None = None + temperature: float | None + response_format: OpenAIResponseFormatParam diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py new file mode 100644 index 000000000..1507a55c8 --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import uuid + +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseInput, + OpenAIResponseInputFunctionToolCallOutput, + OpenAIResponseInputMessageContent, + OpenAIResponseInputMessageContentImage, + OpenAIResponseInputMessageContentText, + OpenAIResponseInputTool, + OpenAIResponseMessage, + OpenAIResponseOutputMessageContent, + OpenAIResponseOutputMessageContentOutputText, + OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseText, +) +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartParam, + OpenAIChatCompletionContentPartTextParam, + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + OpenAIChoice, + OpenAIDeveloperMessageParam, + OpenAIImageURL, + OpenAIJSONSchema, + OpenAIMessageParam, + OpenAIResponseFormatJSONObject, + OpenAIResponseFormatJSONSchema, + OpenAIResponseFormatParam, + OpenAIResponseFormatText, + OpenAISystemMessageParam, + OpenAIToolMessageParam, + OpenAIUserMessageParam, +) + + +async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage: + """Convert an OpenAI Chat Completion choice into an OpenAI Response output message.""" + output_content = "" + if isinstance(choice.message.content, str): + output_content = choice.message.content + elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam): + output_content = choice.message.content.text + else: + raise ValueError( + f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}" + ) + + return OpenAIResponseMessage( + id=f"msg_{uuid.uuid4()}", + content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)], + status="completed", + role="assistant", + ) + + +async def convert_response_content_to_chat_content( + content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]), +) -> str | list[OpenAIChatCompletionContentPartParam]: + """ + Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts. + + The content schemas of each API look similar, but are not exactly the same. + """ + if isinstance(content, str): + return content + + converted_parts = [] + for content_part in content: + if isinstance(content_part, OpenAIResponseInputMessageContentText): + converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) + elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText): + converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) + elif isinstance(content_part, OpenAIResponseInputMessageContentImage): + if content_part.image_url: + image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail) + converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url)) + elif isinstance(content_part, str): + converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part)) + else: + raise ValueError( + f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context" + ) + return converted_parts + + +async def convert_response_input_to_chat_messages( + input: str | list[OpenAIResponseInput], +) -> list[OpenAIMessageParam]: + """ + Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages. + """ + messages: list[OpenAIMessageParam] = [] + if isinstance(input, list): + for input_item in input: + if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): + messages.append( + OpenAIToolMessageParam( + content=input_item.output, + tool_call_id=input_item.call_id, + ) + ) + elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall): + tool_call = OpenAIChatCompletionToolCall( + index=0, + id=input_item.call_id, + function=OpenAIChatCompletionToolCallFunction( + name=input_item.name, + arguments=input_item.arguments, + ), + ) + messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) + else: + content = await convert_response_content_to_chat_content(input_item.content) + message_type = await get_message_type_by_role(input_item.role) + if message_type is None: + raise ValueError( + f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" + ) + messages.append(message_type(content=content)) + else: + messages.append(OpenAIUserMessageParam(content=input)) + return messages + + +async def convert_response_text_to_chat_response_format( + text: OpenAIResponseText, +) -> OpenAIResponseFormatParam: + """ + Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format. + """ + if not text.format or text.format["type"] == "text": + return OpenAIResponseFormatText(type="text") + if text.format["type"] == "json_object": + return OpenAIResponseFormatJSONObject() + if text.format["type"] == "json_schema": + return OpenAIResponseFormatJSONSchema( + json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"]) + ) + raise ValueError(f"Unsupported text format: {text.format}") + + +async def get_message_type_by_role(role: str): + role_to_type = { + "user": OpenAIUserMessageParam, + "system": OpenAISystemMessageParam, + "assistant": OpenAIAssistantMessageParam, + "developer": OpenAIDeveloperMessageParam, + } + return role_to_type.get(role) + + +def is_function_tool_call( + tool_call: OpenAIChatCompletionToolCall, + tools: list[OpenAIResponseInputTool], +) -> bool: + if not tool_call.function: + return False + for t in tools: + if t.type == "function" and t.name == tool_call.function.name: + return True + return False diff --git a/llama_stack/providers/inline/batches/__init__.py b/llama_stack/providers/inline/batches/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/batches/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/batches/reference/__init__.py b/llama_stack/providers/inline/batches/reference/__init__.py new file mode 100644 index 000000000..a8ae92eb2 --- /dev/null +++ b/llama_stack/providers/inline/batches/reference/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.apis.files import Files +from llama_stack.apis.inference import Inference +from llama_stack.apis.models import Models +from llama_stack.core.datatypes import AccessRule, Api +from llama_stack.providers.utils.kvstore import kvstore_impl + +from .batches import ReferenceBatchesImpl +from .config import ReferenceBatchesImplConfig + +__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"] + + +async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]): + kvstore = await kvstore_impl(config.kvstore) + inference_api: Inference | None = deps.get(Api.inference) + files_api: Files | None = deps.get(Api.files) + models_api: Models | None = deps.get(Api.models) + + if inference_api is None: + raise ValueError("Inference API is required but not provided in dependencies") + if files_api is None: + raise ValueError("Files API is required but not provided in dependencies") + if models_api is None: + raise ValueError("Models API is required but not provided in dependencies") + + impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py new file mode 100644 index 000000000..1ff554e70 --- /dev/null +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -0,0 +1,580 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import itertools +import json +import time +import uuid +from io import BytesIO +from typing import Any, Literal + +from openai.types.batch import BatchError, Errors +from pydantic import BaseModel + +from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse +from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError +from llama_stack.apis.files import Files, OpenAIFilePurpose +from llama_stack.apis.inference import ( + Inference, + OpenAIAssistantMessageParam, + OpenAIDeveloperMessageParam, + OpenAIMessageParam, + OpenAISystemMessageParam, + OpenAIToolMessageParam, + OpenAIUserMessageParam, +) +from llama_stack.apis.models import Models +from llama_stack.log import get_logger +from llama_stack.providers.utils.kvstore import KVStore + +from .config import ReferenceBatchesImplConfig + +BATCH_PREFIX = "batch:" + +logger = get_logger(__name__) + + +class AsyncBytesIO: + """ + Async-compatible BytesIO wrapper to allow async file-like operations. + + We use this when uploading files to the Files API, as it expects an + async file-like object. + """ + + def __init__(self, data: bytes): + self._buffer = BytesIO(data) + + async def read(self, n=-1): + return self._buffer.read(n) + + async def seek(self, pos, whence=0): + return self._buffer.seek(pos, whence) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._buffer.close() + + def __getattr__(self, name): + return getattr(self._buffer, name) + + +class BatchRequest(BaseModel): + line_num: int + custom_id: str + method: str + url: str + body: dict[str, Any] + + +def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam: + """Convert a message dictionary to OpenAIMessageParam based on role.""" + role = msg.get("role") + + if role == "user": + return OpenAIUserMessageParam(**msg) + elif role == "system": + return OpenAISystemMessageParam(**msg) + elif role == "assistant": + return OpenAIAssistantMessageParam(**msg) + elif role == "tool": + return OpenAIToolMessageParam(**msg) + elif role == "developer": + return OpenAIDeveloperMessageParam(**msg) + else: + raise ValueError(f"Unknown message role: {role}") + + +class ReferenceBatchesImpl(Batches): + """Reference implementation of the Batches API. + + This implementation processes batch files by making individual requests + to the inference API and generates output files with results. + """ + + def __init__( + self, + config: ReferenceBatchesImplConfig, + inference_api: Inference, + files_api: Files, + models_api: Models, + kvstore: KVStore, + ) -> None: + self.config = config + self.kvstore = kvstore + self.inference_api = inference_api + self.files_api = files_api + self.models_api = models_api + self._processing_tasks: dict[str, asyncio.Task] = {} + self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches) + self._update_batch_lock = asyncio.Lock() + + # this is to allow tests to disable background processing + self.process_batches = True + + async def initialize(self) -> None: + # TODO: start background processing of existing tasks + pass + + async def shutdown(self) -> None: + """Shutdown the batches provider.""" + if self._processing_tasks: + # don't cancel tasks - just let them stop naturally on shutdown + # cancelling would mark batches as "cancelled" in the database + logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks") + + # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions + async def create_batch( + self, + input_file_id: str, + endpoint: str, + completion_window: Literal["24h"], + metadata: dict[str, str] | None = None, + ) -> BatchObject: + """ + Create a new batch for processing multiple API requests. + + Error handling by levels - + 0. Input param handling, results in 40x errors before processing, e.g. + - Wrong completion_window + - Invalid metadata types + - Unknown endpoint + -> no batch created + 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g. + - input_file_id missing + - invalid json in file + - missing custom_id, method, url, body + - invalid model + - streaming + -> batch created, validation sends to failed status + 2. Processing errors, result in error_file_id entries, e.g. + - Any error returned from inference endpoint + -> batch created, goes to completed status + """ + + # TODO: set expiration time for garbage collection + + if endpoint not in ["/v1/chat/completions"]: + raise ValueError( + f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint", + ) + + if completion_window != "24h": + raise ValueError( + f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window", + ) + + batch_id = f"batch_{uuid.uuid4().hex[:16]}" + current_time = int(time.time()) + + batch = BatchObject( + id=batch_id, + object="batch", + endpoint=endpoint, + input_file_id=input_file_id, + completion_window=completion_window, + status="validating", + created_at=current_time, + metadata=metadata, + ) + + await self.kvstore.set(f"batch:{batch_id}", batch.to_json()) + + if self.process_batches: + task = asyncio.create_task(self._process_batch(batch_id)) + self._processing_tasks[batch_id] = task + + return batch + + async def cancel_batch(self, batch_id: str) -> BatchObject: + """Cancel a batch that is in progress.""" + batch = await self.retrieve_batch(batch_id) + + if batch.status in ["cancelled", "cancelling"]: + return batch + + if batch.status in ["completed", "failed", "expired"]: + raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'") + + await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time())) + + if batch_id in self._processing_tasks: + self._processing_tasks[batch_id].cancel() + # note: task removal and status="cancelled" handled in finally block of _process_batch + + return await self.retrieve_batch(batch_id) + + async def list_batches( + self, + after: str | None = None, + limit: int = 20, + ) -> ListBatchesResponse: + """ + List all batches, eventually only for the current user. + + With no notion of user, we return all batches. + """ + batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff") + + batches = [] + for batch_data in batch_values: + if batch_data: + batches.append(BatchObject.model_validate_json(batch_data)) + + batches.sort(key=lambda b: b.created_at, reverse=True) + + start_idx = 0 + if after: + for i, batch in enumerate(batches): + if batch.id == after: + start_idx = i + 1 + break + + page_batches = batches[start_idx : start_idx + limit] + has_more = (start_idx + limit) < len(batches) + + first_id = page_batches[0].id if page_batches else None + last_id = page_batches[-1].id if page_batches else None + + return ListBatchesResponse( + data=page_batches, + first_id=first_id, + last_id=last_id, + has_more=has_more, + ) + + async def retrieve_batch(self, batch_id: str) -> BatchObject: + """Retrieve information about a specific batch.""" + batch_data = await self.kvstore.get(f"batch:{batch_id}") + if not batch_data: + raise ResourceNotFoundError(batch_id, "Batch", "batches.list()") + + return BatchObject.model_validate_json(batch_data) + + async def _update_batch(self, batch_id: str, **updates) -> None: + """Update batch fields in kvstore.""" + async with self._update_batch_lock: + try: + batch = await self.retrieve_batch(batch_id) + + # batch processing is async. once cancelling, only allow "cancelled" status updates + if batch.status == "cancelling" and updates.get("status") != "cancelled": + logger.info( + f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}" + ) + return + + if "errors" in updates: + updates["errors"] = updates["errors"].model_dump() + + batch_dict = batch.model_dump() + batch_dict.update(updates) + + await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict)) + except Exception as e: + logger.error(f"Failed to update batch {batch_id}: {e}") + + async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]: + """ + Read & validate input, return errors and valid input. + + Validation of + - input_file_id existance + - valid json + - custom_id, method, url, body presence and valid + - no streaming + """ + requests: list[BatchRequest] = [] + errors: list[BatchError] = [] + try: + await self.files_api.openai_retrieve_file(batch.input_file_id) + except Exception: + errors.append( + BatchError( + code="invalid_request", + line=None, + message=f"Cannot find file {batch.input_file_id}.", + param="input_file_id", + ) + ) + return errors, requests + + # TODO(SECURITY): do something about large files + file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id) + file_content = file_content_response.body.decode("utf-8") + for line_num, line in enumerate(file_content.strip().split("\n"), 1): + if line.strip(): # skip empty lines + try: + request = json.loads(line) + + if not isinstance(request, dict): + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message="Each line must be a JSON dictionary object", + ) + ) + continue + + valid = True + + for param, expected_type, type_string in [ + ("custom_id", str, "string"), + ("method", str, "string"), + ("url", str, "string"), + ("body", dict, "JSON dictionary object"), + ]: + if param not in request: + errors.append( + BatchError( + code="missing_required_parameter", + line=line_num, + message=f"Missing required parameter: {param}", + param=param, + ) + ) + valid = False + elif not isinstance(request[param], expected_type): + param_name = "URL" if param == "url" else param.capitalize() + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message=f"{param_name} must be a {type_string}", + param=param, + ) + ) + valid = False + + if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint: + errors.append( + BatchError( + code="invalid_url", + line=line_num, + message="URL provided for this request does not match the batch endpoint", + param="url", + ) + ) + valid = False + + if (body := request.get("body")) and isinstance(body, dict): + if body.get("stream", False): + errors.append( + BatchError( + code="streaming_unsupported", + line=line_num, + message="Streaming is not supported in batch processing", + param="body.stream", + ) + ) + valid = False + + for param, expected_type, type_string in [ + ("model", str, "a string"), + # messages is specific to /v1/chat/completions + # we could skip validating messages here and let inference fail. however, + # that would be a very expensive way to find out messages is wrong. + ("messages", list, "an array"), # TODO: allow messages to be a string? + ]: + if param not in body: + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message=f"{param.capitalize()} parameter is required", + param=f"body.{param}", + ) + ) + valid = False + elif not isinstance(body[param], expected_type): + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message=f"{param.capitalize()} must be {type_string}", + param=f"body.{param}", + ) + ) + valid = False + + if "model" in body and isinstance(body["model"], str): + try: + await self.models_api.get_model(body["model"]) + except Exception: + errors.append( + BatchError( + code="model_not_found", + line=line_num, + message=f"Model '{body['model']}' does not exist or is not supported", + param="body.model", + ) + ) + valid = False + + if valid: + assert isinstance(url, str), "URL must be a string" # for mypy + assert isinstance(body, dict), "Body must be a dictionary" # for mypy + requests.append( + BatchRequest( + line_num=line_num, + url=url, + method=request["method"], + custom_id=request["custom_id"], + body=body, + ), + ) + except json.JSONDecodeError: + errors.append( + BatchError( + code="invalid_json_line", + line=line_num, + message="This line is not parseable as valid JSON.", + ) + ) + + return errors, requests + + async def _process_batch(self, batch_id: str) -> None: + """Background task to process a batch of requests.""" + try: + logger.info(f"Starting batch processing for {batch_id}") + async with self._batch_semaphore: # semaphore to limit concurrency + logger.info(f"Acquired semaphore for batch {batch_id}") + await self._process_batch_impl(batch_id) + except asyncio.CancelledError: + logger.info(f"Batch processing cancelled for {batch_id}") + await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time())) + except Exception as e: + logger.error(f"Batch processing failed for {batch_id}: {e}") + await self._update_batch( + batch_id, + status="failed", + failed_at=int(time.time()), + errors=Errors(data=[BatchError(code="internal_error", message=str(e))]), + ) + finally: + self._processing_tasks.pop(batch_id, None) + + async def _process_batch_impl(self, batch_id: str) -> None: + """Implementation of batch processing logic.""" + errors: list[BatchError] = [] + batch = await self.retrieve_batch(batch_id) + + errors, requests = await self._validate_input(batch) + if errors: + await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)) + logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors") + return + + logger.info(f"Processing {len(requests)} requests for batch {batch_id}") + + total_requests = len(requests) + await self._update_batch( + batch_id, + status="in_progress", + request_counts={"total": total_requests, "completed": 0, "failed": 0}, + ) + + error_results = [] + success_results = [] + completed_count = 0 + failed_count = 0 + + for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch): + # we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled + async with asyncio.TaskGroup() as tg: + chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk] + + chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True) + + for result in chunk_results: + if isinstance(result, dict) and result.get("error") is not None: # error response from inference + failed_count += 1 + error_results.append(result) + elif isinstance(result, dict) and result.get("response") is not None: # successful inference + completed_count += 1 + success_results.append(result) + else: # unexpected result + failed_count += 1 + errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}")) + + await self._update_batch( + batch_id, + request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count}, + ) + + if errors: + await self._update_batch( + batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors) + ) + return + + try: + output_file_id = await self._create_output_file(batch_id, success_results, "success") + await self._update_batch(batch_id, output_file_id=output_file_id) + + error_file_id = await self._create_output_file(batch_id, error_results, "error") + await self._update_batch(batch_id, error_file_id=error_file_id) + + await self._update_batch(batch_id, status="completed", completed_at=int(time.time())) + + logger.info( + f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed" + ) + except Exception as e: + # note: errors is empty at this point, so we don't lose anything by ignoring it + await self._update_batch( + batch_id, + status="failed", + failed_at=int(time.time()), + errors=Errors(data=[BatchError(code="output_failed", message=str(e))]), + ) + + async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict: + """Process a single request from the batch.""" + request_id = f"batch_req_{batch_id}_{request.line_num}" + + try: + # TODO(SECURITY): review body for security issues + request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]] + chat_response = await self.inference_api.openai_chat_completion(**request.body) + + # this is for mypy, we don't allow streaming so we'll get the right type + assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" + return { + "id": request_id, + "custom_id": request.custom_id, + "response": { + "status_code": 200, + "request_id": request_id, # TODO: should this be different? + "body": chat_response.model_dump_json(), + }, + } + except Exception as e: + logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") + return { + "id": request_id, + "custom_id": request.custom_id, + "error": {"type": "request_failed", "message": str(e)}, + } + + async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str: + """ + Create an output file with batch results. + + This function filters results based on the specified file_type + and uploads the file to the Files API. + """ + output_lines = [json.dumps(result) for result in results] + + with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer: + file_buffer.filename = f"{batch_id}_{file_type}.jsonl" + uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH) + return uploaded_file.id diff --git a/llama_stack/providers/inline/batches/reference/config.py b/llama_stack/providers/inline/batches/reference/config.py new file mode 100644 index 000000000..d8d06868b --- /dev/null +++ b/llama_stack/providers/inline/batches/reference/config.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig + + +class ReferenceBatchesImplConfig(BaseModel): + """Configuration for the Reference Batches implementation.""" + + kvstore: KVStoreConfig = Field( + description="Configuration for the key-value store backend.", + ) + + max_concurrent_batches: int = Field( + default=1, + description="Maximum number of concurrent batches to process simultaneously.", + ge=1, + ) + + max_concurrent_requests_per_batch: int = Field( + default=10, + description="Maximum number of concurrent requests to process per batch.", + ge=1, + ) + + # TODO: add a max requests per second rate limiter + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> dict: + return { + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="batches.db", + ), + } diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index f83c39a6a..bae744010 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -22,7 +22,7 @@ from llama_stack.apis.safety import ( SafetyViolation, ViolationLevel, ) -from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories +from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults from llama_stack.apis.shields import Shield from llama_stack.core.datatypes import Api from llama_stack.models.llama.datatypes import Role @@ -72,30 +72,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = { } SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()} -OPENAI_TO_LLAMA_CATEGORIES_MAP = { - OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES], - OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES], - OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION], - OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION], - OpenAICategories.HATE: [CAT_HATE], - OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES], - OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES], - OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS], - OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT], - OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION], - OpenAICategories.SELF_HARM: [CAT_SELF_HARM], - OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM], - OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE], - # These are custom categories that are not in the OpenAI moderation categories - "custom/defamation": [CAT_DEFAMATION], - "custom/specialized_advice": [CAT_SPECIALIZED_ADVICE], - "custom/privacy_violation": [CAT_PRIVACY], - "custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY], - "custom/weapons": [CAT_INDISCRIMINATE_WEAPONS], - "custom/elections": [CAT_ELECTIONS], - "custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE], -} - DEFAULT_LG_V3_SAFETY_CATEGORIES = [ CAT_VIOLENT_CRIMES, @@ -424,9 +400,9 @@ class LlamaGuardShield: ModerationObject with appropriate configuration """ # Set default values for safe case - categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False) - category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0) - category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()} + categories = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), False) + category_scores = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), 1.0) + category_applied_input_types = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()} flagged = False user_message = None metadata = {} @@ -453,19 +429,15 @@ class LlamaGuardShield: ], ) - # Get OpenAI categories for the unsafe codes - openai_categories = [] - for code in unsafe_code_list: - llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code] - openai_categories.extend( - k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l - ) + llama_guard_category = [SAFETY_CODE_TO_CATEGORIES_MAP[code] for code in unsafe_code_list] # Update categories for unsafe content - categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP} - category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP} + categories = {k: k in llama_guard_category for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()} + category_scores = { + k: 1.0 if k in llama_guard_category else 0.0 for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys() + } category_applied_input_types = { - k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP + k: ["text"] if k in llama_guard_category else [] for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys() } flagged = True user_message = CANNED_RESPONSE_TEXT diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index e11ec5cf5..c760f0fd1 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -18,6 +18,7 @@ from llama_stack.apis.safety import ( ShieldStore, ViolationLevel, ) +from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield from llama_stack.core.utils.model_utils import model_local_dir from llama_stack.providers.datatypes import ShieldsProtocolPrivate @@ -64,6 +65,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): return await self.shield.run(messages) + async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: + raise NotImplementedError("run_moderation is not implemented for Prompt Guard") + class PromptGuardShield: def __init__( diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 5a063592c..af61da59b 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -33,6 +33,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -128,11 +129,12 @@ class FaissIndex(EmbeddingIndex): # Save updated index await self._save_index() - async def delete_chunk(self, chunk_id: str) -> None: - if chunk_id not in self.chunk_ids: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: + chunk_ids = [c.chunk_id for c in chunks_for_deletion] + if not set(chunk_ids).issubset(self.chunk_ids): return - async with self.chunk_id_lock: + def remove_chunk(chunk_id: str): index = self.chunk_ids.index(chunk_id) self.index.remove_ids(np.array([index])) @@ -146,6 +148,10 @@ class FaissIndex(EmbeddingIndex): self.chunk_by_index = new_chunk_by_index self.chunk_ids.pop(index) + async with self.chunk_id_lock: + for chunk_id in chunk_ids: + remove_chunk(chunk_id) + await self._save_index() async def query_vector( @@ -297,8 +303,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr return await index.query_chunks(query, params) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: - """Delete a chunk from a faiss index""" + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete chunks from a faiss index""" faiss_index = self.cache[store_id].index - for chunk_id in chunk_ids: - await faiss_index.delete_chunk(chunk_id) + await faiss_index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 1fff7b484..cc1982f3b 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -31,6 +31,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIV from llama_stack.providers.utils.memory.vector_store import ( RERANKER_TYPE_RRF, RERANKER_TYPE_WEIGHTED, + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -426,34 +427,36 @@ class SQLiteVecIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) - async def delete_chunk(self, chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Remove a chunk from the SQLite vector store.""" + chunk_ids = [c.chunk_id for c in chunks_for_deletion] - def _delete_chunk(): + def _delete_chunks(): connection = _create_sqlite_connection(self.db_path) cur = connection.cursor() try: cur.execute("BEGIN TRANSACTION") # Delete from metadata table - cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,)) + placeholders = ",".join("?" * len(chunk_ids)) + cur.execute(f"DELETE FROM {self.metadata_table} WHERE id IN ({placeholders})", chunk_ids) # Delete from vector table - cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,)) + cur.execute(f"DELETE FROM {self.vector_table} WHERE id IN ({placeholders})", chunk_ids) # Delete from FTS table - cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,)) + cur.execute(f"DELETE FROM {self.fts_table} WHERE id IN ({placeholders})", chunk_ids) connection.commit() except Exception as e: connection.rollback() - logger.error(f"Error deleting chunk {chunk_id}: {e}") + logger.error(f"Error deleting chunks: {e}") raise finally: cur.close() connection.close() - await asyncio.to_thread(_delete_chunk) + await asyncio.to_thread(_delete_chunks) class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): @@ -551,12 +554,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc raise VectorStoreNotFoundError(vector_db_id) return await index.query_chunks(query, params) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: - """Delete a chunk from a sqlite_vec index.""" + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete chunks from a sqlite_vec index.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: raise VectorStoreNotFoundError(store_id) - for chunk_id in chunk_ids: - # Use the index's delete_chunk method - await index.index.delete_chunk(chunk_id) + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/registry/batches.py b/llama_stack/providers/registry/batches.py new file mode 100644 index 000000000..de7886efb --- /dev/null +++ b/llama_stack/providers/registry/batches.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec + + +def available_providers() -> list[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.batches, + provider_type="inline::reference", + pip_packages=["openai"], + module="llama_stack.providers.inline.batches.reference", + config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig", + api_dependencies=[ + Api.inference, + Api.files, + Api.models, + ], + description="Reference implementation of batches API with KVStore persistence.", + ), + ] diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index ed170b508..70148eb15 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -342,6 +342,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti """, ), api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, @@ -350,6 +351,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti module="llama_stack.providers.inline.vector_io.chroma", config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig", api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], description=""" [Chroma](https://www.trychroma.com/) is an inline and remote vector database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database. @@ -464,6 +466,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more """, ), api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, @@ -731,6 +734,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi """, ), api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index ca4c7b578..bd86f7238 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -235,6 +235,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv llama_model = self.get_llama_model(request.model) if isinstance(request, ChatCompletionRequest): + # TODO: tools are never added to the request, so we need to add them here if media_present or not llama_model: input_dict["messages"] = [ await convert_message_to_openai_dict(m, download=True) for m in request.messages @@ -378,6 +379,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv # Fireworks chat completions OpenAI-compatible API does not support # tool calls properly. llama_model = self.get_llama_model(model_obj.provider_resource_id) + if llama_model: return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion( self, @@ -431,4 +433,5 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv user=user, ) + logger.debug(f"fireworks params: {params}") return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 26b4dec76..a93421536 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -457,9 +457,6 @@ class OllamaInferenceAdapter( user: str | None = None, ) -> OpenAIEmbeddingsResponse: model_obj = await self._get_model(model) - if model_obj.model_type != ModelType.embedding: - raise ValueError(f"Model {model} is not an embedding model") - if model_obj.provider_resource_id is None: raise ValueError(f"Model {model} has no provider_resource_id set") diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index a5bb079ef..323831845 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -308,9 +308,7 @@ class TGIAdapter(_HfAdapter): if not config.url: raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.") log.info(f"Initializing TGI client with url={config.url}") - self.client = AsyncInferenceClient( - model=config.url, - ) + self.client = AsyncInferenceClient(model=config.url, provider="hf-inference") endpoint_info = await self.client.get_endpoint_info() self.max_tokens = endpoint_info["max_total_tokens"] self.model_id = endpoint_info["model_id"] diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 75226a560..9c7a7732a 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -26,6 +26,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -146,8 +147,10 @@ class ChromaIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) - async def delete_chunk(self, chunk_id: str) -> None: - await maybe_await(self.collection.delete([chunk_id])) + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete a single chunk from the Chroma collection by its ID.""" + ids = [f"{chunk.document_id}:{chunk.chunk_id}" for chunk in chunks_for_deletion] + await maybe_await(self.collection.delete(ids=ids)) async def query_hybrid( self, @@ -175,6 +178,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.cache = {} self.kvstore: KVStore | None = None self.vector_db_store = None + self.files_api = files_api async def initialize(self) -> None: self.kvstore = await kvstore_impl(self.config.kvstore) @@ -258,5 +262,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.cache[vector_db_id] = index return index - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma") + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete chunks from a Chroma vector store.""" + index = await self._get_and_cache_vector_db_index(store_id) + if not index: + raise ValueError(f"Vector DB {store_id} not found") + + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index b09edb65c..0eaae81b3 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -28,6 +28,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( RERANKER_TYPE_WEIGHTED, + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -287,14 +288,17 @@ class MilvusIndex(EmbeddingIndex): return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores) - async def delete_chunk(self, chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Remove a chunk from the Milvus collection.""" + chunk_ids = [c.chunk_id for c in chunks_for_deletion] try: + # Use IN clause with square brackets and single quotes for VARCHAR field + chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids) await asyncio.to_thread( - self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"' + self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]" ) except Exception as e: - logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}") + logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}") raise @@ -420,12 +424,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP return await index.query_chunks(query, params) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Delete a chunk from a milvus vector store.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: raise VectorStoreNotFoundError(store_id) - for chunk_id in chunk_ids: - # Use the index's delete_chunk method - await index.index.delete_chunk(chunk_id) + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index b1645ac5a..d2a5d910b 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -163,10 +164,11 @@ class PGVectorIndex(EmbeddingIndex): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") - async def delete_chunk(self, chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Remove a chunk from the PostgreSQL table.""" + chunk_ids = [c.chunk_id for c in chunks_for_deletion] with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,)) + cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,)) class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): @@ -275,12 +277,10 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api) return self.cache[vector_db_id] - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Delete a chunk from a PostgreSQL vector store.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: raise VectorStoreNotFoundError(store_id) - for chunk_id in chunk_ids: - # Use the index's delete_chunk method - await index.index.delete_chunk(chunk_id) + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 144da0f4f..018015780 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -29,6 +29,7 @@ from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig a from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -88,15 +89,16 @@ class QdrantIndex(EmbeddingIndex): await self.client.upsert(collection_name=self.collection_name, points=points) - async def delete_chunk(self, chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Remove a chunk from the Qdrant collection.""" + chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion] try: await self.client.delete( collection_name=self.collection_name, - points_selector=models.PointIdsList(points=[convert_id(chunk_id)]), + points_selector=models.PointIdsList(points=chunk_ids), ) except Exception as e: - log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}") + log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}") raise async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: @@ -264,12 +266,14 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP ) -> VectorStoreFileObject: # Qdrant doesn't allow multiple clients to access the same storage path simultaneously. async with self._qdrant_lock: - await super().openai_attach_file_to_vector_store(vector_store_id, file_id, attributes, chunking_strategy) + return await super().openai_attach_file_to_vector_store( + vector_store_id, file_id, attributes, chunking_strategy + ) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Delete chunks from a Qdrant vector store.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: raise ValueError(f"Vector DB {store_id} not found") - for chunk_id in chunk_ids: - await index.index.delete_chunk(chunk_id) + + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 11da8902c..966724848 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -26,6 +26,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import ( OpenAIVectorStoreMixin, ) from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -67,6 +68,7 @@ class WeaviateIndex(EmbeddingIndex): data_objects.append( wvc.data.DataObject( properties={ + "chunk_id": chunk.chunk_id, "chunk_content": chunk.model_dump_json(), }, vector=embeddings[i].tolist(), @@ -79,10 +81,11 @@ class WeaviateIndex(EmbeddingIndex): # TODO: make this async friendly collection.data.insert_many(data_objects) - async def delete_chunk(self, chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True) collection = self.client.collections.get(sanitized_collection_name) - collection.data.delete_many(where=Filter.by_property("id").contains_any([chunk_id])) + chunk_ids = [chunk.chunk_id for chunk in chunks_for_deletion] + collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids)) async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True) @@ -307,10 +310,10 @@ class WeaviateVectorIOAdapter( return await index.query_chunks(query, params) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True) index = await self._get_and_cache_vector_db_index(sanitized_collection_name) if not index: raise ValueError(f"Vector DB {sanitized_collection_name} not found") - await index.delete(chunk_ids) + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index e6e5ccc8a..6297cc2ed 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -31,15 +31,15 @@ from openai.types.chat import ( from openai.types.chat import ( ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, ) +from openai.types.chat import ( + ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall, +) from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessage, ) from openai.types.chat import ( ChatCompletionMessageToolCall, ) -from openai.types.chat import ( - ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall, -) from openai.types.chat import ( ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, ) @@ -70,7 +70,7 @@ from openai.types.chat.chat_completion_chunk import ( from openai.types.chat.chat_completion_content_part_image_param import ( ImageURL as OpenAIImageURL, ) -from openai.types.chat.chat_completion_message_tool_call_param import ( +from openai.types.chat.chat_completion_message_tool_call import ( Function as OpenAIFunction, ) from pydantic import BaseModel @@ -633,7 +633,7 @@ async def convert_message_to_openai_dict_new( ) elif isinstance(message, CompletionMessage): tool_calls = [ - OpenAIChatCompletionMessageToolCall( + OpenAIChatCompletionMessageFunctionToolCall( id=tool.call_id, function=OpenAIFunction( name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), @@ -903,7 +903,7 @@ def _convert_openai_request_response_format( def _convert_openai_tool_calls( - tool_calls: list[OpenAIChatCompletionMessageToolCall], + tool_calls: list[OpenAIChatCompletionMessageFunctionToolCall], ) -> list[ToolCall]: """ Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall. diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 7b6e69df1..120d0d4fc 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -6,7 +6,6 @@ import asyncio import json -import logging import mimetypes import time import uuid @@ -37,10 +36,15 @@ from llama_stack.apis.vector_io import ( VectorStoreSearchResponse, VectorStoreSearchResponsePage, ) +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore.api import KVStore -from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks +from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, + content_from_data_and_mime_type, + make_overlapped_chunks, +) -logger = logging.getLogger(__name__) +logger = get_logger(__name__, category="vector_io") # Constants for OpenAI vector stores CHUNK_MULTIPLIER = 5 @@ -154,8 +158,8 @@ class OpenAIVectorStoreMixin(ABC): self.openai_vector_stores = await self._load_openai_vector_stores() @abstractmethod - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: - """Delete a chunk from a vector store.""" + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete chunks from a vector store.""" pass @abstractmethod @@ -614,7 +618,7 @@ class OpenAIVectorStoreMixin(ABC): ) vector_store_file_object.status = "completed" except Exception as e: - logger.error(f"Error attaching file to vector store: {e}") + logger.exception("Error attaching file to vector store") vector_store_file_object.status = "failed" vector_store_file_object.last_error = VectorStoreFileLastError( code="server_error", @@ -767,7 +771,21 @@ class OpenAIVectorStoreMixin(ABC): dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id) chunks = [Chunk.model_validate(c) for c in dict_chunks] - await self.delete_chunks(vector_store_id, [str(c.chunk_id) for c in chunks if c.chunk_id]) + + # Create ChunkForDeletion objects with both chunk_id and document_id + chunks_for_deletion = [] + for c in chunks: + if c.chunk_id: + document_id = c.metadata.get("document_id") or ( + c.chunk_metadata.document_id if c.chunk_metadata else None + ) + if document_id: + chunks_for_deletion.append(ChunkForDeletion(chunk_id=str(c.chunk_id), document_id=document_id)) + else: + logger.warning(f"Chunk {c.chunk_id} has no document_id, skipping deletion") + + if chunks_for_deletion: + await self.delete_chunks(vector_store_id, chunks_for_deletion) store_info = self.openai_vector_stores[vector_store_id].copy() diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index bb9002f30..6ae5bb521 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -16,6 +16,7 @@ from urllib.parse import unquote import httpx import numpy as np from numpy.typing import NDArray +from pydantic import BaseModel from llama_stack.apis.common.content_types import ( URL, @@ -34,6 +35,18 @@ from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id log = logging.getLogger(__name__) + +class ChunkForDeletion(BaseModel): + """Information needed to delete a chunk from a vector store. + + :param chunk_id: The ID of the chunk to delete + :param document_id: The ID of the document this chunk belongs to + """ + + chunk_id: str + document_id: str + + # Constants for reranker types RERANKER_TYPE_RRF = "rrf" RERANKER_TYPE_WEIGHTED = "weighted" @@ -232,7 +245,7 @@ class EmbeddingIndex(ABC): raise NotImplementedError() @abstractmethod - async def delete_chunk(self, chunk_id: str): + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]): raise NotImplementedError() @abstractmethod diff --git a/llama_stack/ui/.nvmrc b/llama_stack/ui/.nvmrc new file mode 100644 index 000000000..1384ff6a1 --- /dev/null +++ b/llama_stack/ui/.nvmrc @@ -0,0 +1 @@ +22.5.1 diff --git a/llama_stack/ui/.prettierignore b/llama_stack/ui/.prettierignore index 1b8ac8894..b737ae6ed 100644 --- a/llama_stack/ui/.prettierignore +++ b/llama_stack/ui/.prettierignore @@ -1,3 +1,12 @@ # Ignore artifacts: build coverage +.next +node_modules +dist +*.lock +*.log + +# Generated files +*.min.js +*.min.css diff --git a/llama_stack/ui/.prettierrc b/llama_stack/ui/.prettierrc index 0967ef424..059475a24 100644 --- a/llama_stack/ui/.prettierrc +++ b/llama_stack/ui/.prettierrc @@ -1 +1,10 @@ -{} +{ + "semi": true, + "trailingComma": "es5", + "singleQuote": false, + "printWidth": 80, + "tabWidth": 2, + "useTabs": false, + "bracketSpacing": true, + "arrowParens": "avoid" +} diff --git a/llama_stack/ui/app/api/v1/[...path]/route.ts b/llama_stack/ui/app/api/v1/[...path]/route.ts index 1959f9099..51c1f8004 100644 --- a/llama_stack/ui/app/api/v1/[...path]/route.ts +++ b/llama_stack/ui/app/api/v1/[...path]/route.ts @@ -47,7 +47,7 @@ async function proxyRequest(request: NextRequest, method: string) { const responseText = await response.text(); console.log( - `Response from FastAPI: ${response.status} ${response.statusText}`, + `Response from FastAPI: ${response.status} ${response.statusText}` ); // Create response with same status and headers @@ -74,7 +74,7 @@ async function proxyRequest(request: NextRequest, method: string) { backend_url: BACKEND_URL, timestamp: new Date().toISOString(), }, - { status: 500 }, + { status: 500 } ); } } diff --git a/llama_stack/ui/app/auth/signin/page.tsx b/llama_stack/ui/app/auth/signin/page.tsx index c9510fd6b..0ccb4a397 100644 --- a/llama_stack/ui/app/auth/signin/page.tsx +++ b/llama_stack/ui/app/auth/signin/page.tsx @@ -51,9 +51,9 @@ export default function SignInPage() { onClick={() => { console.log("Signing in with GitHub..."); signIn("github", { callbackUrl: "/auth/signin" }).catch( - (error) => { + error => { console.error("Sign in error:", error); - }, + } ); }} className="w-full" diff --git a/llama_stack/ui/app/chat-playground/page.tsx b/llama_stack/ui/app/chat-playground/page.tsx index d8094af85..b8651aca0 100644 --- a/llama_stack/ui/app/chat-playground/page.tsx +++ b/llama_stack/ui/app/chat-playground/page.tsx @@ -29,14 +29,13 @@ export default function ChatPlaygroundPage() { const isModelsLoading = modelsLoading ?? true; - useEffect(() => { const fetchModels = async () => { try { setModelsLoading(true); setModelsError(null); const modelList = await client.models.list(); - const llmModels = modelList.filter(model => model.model_type === 'llm'); + const llmModels = modelList.filter(model => model.model_type === "llm"); setModels(llmModels); if (llmModels.length > 0) { setSelectedModel(llmModels[0].identifier); @@ -53,103 +52,122 @@ export default function ChatPlaygroundPage() { }, [client]); const extractTextContent = (content: unknown): string => { - if (typeof content === 'string') { + if (typeof content === "string") { return content; } if (Array.isArray(content)) { return content - .filter(item => item && typeof item === 'object' && 'type' in item && item.type === 'text') - .map(item => (item && typeof item === 'object' && 'text' in item) ? String(item.text) : '') - .join(''); + .filter( + item => + item && + typeof item === "object" && + "type" in item && + item.type === "text" + ) + .map(item => + item && typeof item === "object" && "text" in item + ? String(item.text) + : "" + ) + .join(""); } - if (content && typeof content === 'object' && 'type' in content && content.type === 'text' && 'text' in content) { - return String(content.text) || ''; + if ( + content && + typeof content === "object" && + "type" in content && + content.type === "text" && + "text" in content + ) { + return String(content.text) || ""; } - return ''; + return ""; }; const handleInputChange = (e: React.ChangeEvent) => { setInput(e.target.value); }; -const handleSubmit = async (event?: { preventDefault?: () => void }) => { - event?.preventDefault?.(); - if (!input.trim()) return; + const handleSubmit = async (event?: { preventDefault?: () => void }) => { + event?.preventDefault?.(); + if (!input.trim()) return; - // Add user message to chat - const userMessage: Message = { - id: Date.now().toString(), - role: "user", - content: input.trim(), - createdAt: new Date(), - }; - - setMessages(prev => [...prev, userMessage]); - setInput(""); - - // Use the helper function with the content - await handleSubmitWithContent(userMessage.content); -}; - -const handleSubmitWithContent = async (content: string) => { - setIsGenerating(true); - setError(null); - - try { - const messageParams: CompletionCreateParams["messages"] = [ - ...messages.map(msg => { - const msgContent = typeof msg.content === 'string' ? msg.content : extractTextContent(msg.content); - if (msg.role === "user") { - return { role: "user" as const, content: msgContent }; - } else if (msg.role === "assistant") { - return { role: "assistant" as const, content: msgContent }; - } else { - return { role: "system" as const, content: msgContent }; - } - }), - { role: "user" as const, content } - ]; - - const response = await client.chat.completions.create({ - model: selectedModel, - messages: messageParams, - stream: true, - }); - - const assistantMessage: Message = { - id: (Date.now() + 1).toString(), - role: "assistant", - content: "", + // Add user message to chat + const userMessage: Message = { + id: Date.now().toString(), + role: "user", + content: input.trim(), createdAt: new Date(), }; - setMessages(prev => [...prev, assistantMessage]); - let fullContent = ""; - for await (const chunk of response) { - if (chunk.choices && chunk.choices[0]?.delta?.content) { - const deltaContent = chunk.choices[0].delta.content; - fullContent += deltaContent; + setMessages(prev => [...prev, userMessage]); + setInput(""); - flushSync(() => { - setMessages(prev => { - const newMessages = [...prev]; - const lastMessage = newMessages[newMessages.length - 1]; - if (lastMessage.role === "assistant") { - lastMessage.content = fullContent; - } - return newMessages; + // Use the helper function with the content + await handleSubmitWithContent(userMessage.content); + }; + + const handleSubmitWithContent = async (content: string) => { + setIsGenerating(true); + setError(null); + + try { + const messageParams: CompletionCreateParams["messages"] = [ + ...messages.map(msg => { + const msgContent = + typeof msg.content === "string" + ? msg.content + : extractTextContent(msg.content); + if (msg.role === "user") { + return { role: "user" as const, content: msgContent }; + } else if (msg.role === "assistant") { + return { role: "assistant" as const, content: msgContent }; + } else { + return { role: "system" as const, content: msgContent }; + } + }), + { role: "user" as const, content }, + ]; + + const response = await client.chat.completions.create({ + model: selectedModel, + messages: messageParams, + stream: true, + }); + + const assistantMessage: Message = { + id: (Date.now() + 1).toString(), + role: "assistant", + content: "", + createdAt: new Date(), + }; + + setMessages(prev => [...prev, assistantMessage]); + let fullContent = ""; + for await (const chunk of response) { + if (chunk.choices && chunk.choices[0]?.delta?.content) { + const deltaContent = chunk.choices[0].delta.content; + fullContent += deltaContent; + + flushSync(() => { + setMessages(prev => { + const newMessages = [...prev]; + const lastMessage = newMessages[newMessages.length - 1]; + if (lastMessage.role === "assistant") { + lastMessage.content = fullContent; + } + return newMessages; + }); }); - }); + } } + } catch (err) { + console.error("Error sending message:", err); + setError("Failed to send message. Please try again."); + setMessages(prev => prev.slice(0, -1)); + } finally { + setIsGenerating(false); } - } catch (err) { - console.error("Error sending message:", err); - setError("Failed to send message. Please try again."); - setMessages(prev => prev.slice(0, -1)); - } finally { - setIsGenerating(false); - } -}; + }; const suggestions = [ "Write a Python function that prints 'Hello, World!'", "Explain step-by-step how to solve this math problem: If xΒ² + 6x + 9 = 25, what is x?", @@ -163,7 +181,7 @@ const handleSubmitWithContent = async (content: string) => { content: message.content, createdAt: new Date(), }; - setMessages(prev => [...prev, newMessage]) + setMessages(prev => [...prev, newMessage]); handleSubmitWithContent(newMessage.content); }; @@ -177,12 +195,20 @@ const handleSubmitWithContent = async (content: string) => {

Chat Playground (Completions)

- - + - {models.map((model) => ( + {models.map(model => ( {model.identifier} diff --git a/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx b/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx index 82aa3496e..e11924f4c 100644 --- a/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx +++ b/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx @@ -33,12 +33,12 @@ export default function ChatCompletionDetailPage() { } catch (err) { console.error( `Error fetching chat completion detail for ID ${id}:`, - err, + err ); setError( err instanceof Error ? err - : new Error("Failed to fetch completion detail"), + : new Error("Failed to fetch completion detail") ); } finally { setIsLoading(false); diff --git a/llama_stack/ui/app/logs/responses/[id]/page.tsx b/llama_stack/ui/app/logs/responses/[id]/page.tsx index 7f4252856..922d35531 100644 --- a/llama_stack/ui/app/logs/responses/[id]/page.tsx +++ b/llama_stack/ui/app/logs/responses/[id]/page.tsx @@ -13,10 +13,10 @@ export default function ResponseDetailPage() { const client = useAuthClient(); const [responseDetail, setResponseDetail] = useState( - null, + null ); const [inputItems, setInputItems] = useState( - null, + null ); const [isLoading, setIsLoading] = useState(true); const [isLoadingInputItems, setIsLoadingInputItems] = useState(true); @@ -25,7 +25,7 @@ export default function ResponseDetailPage() { // Helper function to convert ResponseObject to OpenAIResponse const convertResponseObject = ( - responseData: ResponseObject, + responseData: ResponseObject ): OpenAIResponse => { return { id: responseData.id, @@ -73,12 +73,12 @@ export default function ResponseDetailPage() { } else { console.error( `Error fetching response detail for ID ${id}:`, - responseResult.reason, + responseResult.reason ); setError( responseResult.reason instanceof Error ? responseResult.reason - : new Error("Failed to fetch response detail"), + : new Error("Failed to fetch response detail") ); } @@ -90,18 +90,18 @@ export default function ResponseDetailPage() { } else { console.error( `Error fetching input items for response ID ${id}:`, - inputItemsResult.reason, + inputItemsResult.reason ); setInputItemsError( inputItemsResult.reason instanceof Error ? inputItemsResult.reason - : new Error("Failed to fetch input items"), + : new Error("Failed to fetch input items") ); } } catch (err) { console.error(`Unexpected error fetching data for ID ${id}:`, err); setError( - err instanceof Error ? err : new Error("Unexpected error occurred"), + err instanceof Error ? err : new Error("Unexpected error occurred") ); } finally { setIsLoading(false); diff --git a/llama_stack/ui/app/logs/vector-stores/[id]/files/[fileId]/contents/[contentId]/page.tsx b/llama_stack/ui/app/logs/vector-stores/[id]/files/[fileId]/contents/[contentId]/page.tsx index 6896b992a..d58de3085 100644 --- a/llama_stack/ui/app/logs/vector-stores/[id]/files/[fileId]/contents/[contentId]/page.tsx +++ b/llama_stack/ui/app/logs/vector-stores/[id]/files/[fileId]/contents/[contentId]/page.tsx @@ -18,7 +18,10 @@ import { PropertiesCard, PropertyItem, } from "@/components/layout/detail-layout"; -import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb"; +import { + PageBreadcrumb, + BreadcrumbSegment, +} from "@/components/layout/page-breadcrumb"; export default function ContentDetailPage() { const params = useParams(); @@ -28,13 +31,13 @@ export default function ContentDetailPage() { const contentId = params.contentId as string; const client = useAuthClient(); - const getTextFromContent = (content: any): string => { - if (typeof content === 'string') { + const getTextFromContent = (content: unknown): string => { + if (typeof content === "string") { return content; - } else if (content && content.type === 'text') { + } else if (content && content.type === "text") { return content.text; } - return ''; + return ""; }; const [store, setStore] = useState(null); @@ -44,7 +47,9 @@ export default function ContentDetailPage() { const [error, setError] = useState(null); const [isEditing, setIsEditing] = useState(false); const [editedContent, setEditedContent] = useState(""); - const [editedMetadata, setEditedMetadata] = useState>({}); + const [editedMetadata, setEditedMetadata] = useState>( + {} + ); const [isEditingEmbedding, setIsEditingEmbedding] = useState(false); const [editedEmbedding, setEditedEmbedding] = useState([]); @@ -64,8 +69,13 @@ export default function ContentDetailPage() { setFile(fileResponse as VectorStoreFile); const contentsAPI = new ContentsAPI(client); - const contentsResponse = await contentsAPI.listContents(vectorStoreId, fileId); - const targetContent = contentsResponse.data.find(c => c.id === contentId); + const contentsResponse = await contentsAPI.listContents( + vectorStoreId, + fileId + ); + const targetContent = contentsResponse.data.find( + c => c.id === contentId + ); if (targetContent) { setContent(targetContent); @@ -76,7 +86,9 @@ export default function ContentDetailPage() { throw new Error(`Content ${contentId} not found`); } } catch (err) { - setError(err instanceof Error ? err : new Error("Failed to load content.")); + setError( + err instanceof Error ? err : new Error("Failed to load content.") + ); } finally { setIsLoading(false); } @@ -88,7 +100,8 @@ export default function ContentDetailPage() { if (!content) return; try { - const updates: { content?: string; metadata?: Record } = {}; + const updates: { content?: string; metadata?: Record } = + {}; if (editedContent !== getTextFromContent(content.content)) { updates.content = editedContent; @@ -100,25 +113,32 @@ export default function ContentDetailPage() { if (Object.keys(updates).length > 0) { const contentsAPI = new ContentsAPI(client); - const updatedContent = await contentsAPI.updateContent(vectorStoreId, fileId, contentId, updates); + const updatedContent = await contentsAPI.updateContent( + vectorStoreId, + fileId, + contentId, + updates + ); setContent(updatedContent); } setIsEditing(false); } catch (err) { - console.error('Failed to update content:', err); + console.error("Failed to update content:", err); } }; const handleDelete = async () => { - if (!confirm('Are you sure you want to delete this content?')) return; + if (!confirm("Are you sure you want to delete this content?")) return; try { const contentsAPI = new ContentsAPI(client); await contentsAPI.deleteContent(vectorStoreId, fileId, contentId); - router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`); + router.push( + `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents` + ); } catch (err) { - console.error('Failed to delete content:', err); + console.error("Failed to delete content:", err); } }; @@ -134,10 +154,19 @@ export default function ContentDetailPage() { const breadcrumbSegments: BreadcrumbSegment[] = [ { label: "Vector Stores", href: "/logs/vector-stores" }, - { label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` }, + { + label: store?.name || vectorStoreId, + href: `/logs/vector-stores/${vectorStoreId}`, + }, { label: "Files", href: `/logs/vector-stores/${vectorStoreId}` }, - { label: fileId, href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}` }, - { label: "Contents", href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents` }, + { + label: fileId, + href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}`, + }, + { + label: "Contents", + href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`, + }, { label: contentId }, ]; @@ -186,7 +215,7 @@ export default function ContentDetailPage() { {isEditing ? (