diff --git a/.github/actions/setup-runner/action.yml b/.github/actions/setup-runner/action.yml index 0be999fe2..1ca02bbff 100644 --- a/.github/actions/setup-runner/action.yml +++ b/.github/actions/setup-runner/action.yml @@ -28,7 +28,7 @@ runs: # Install llama-stack-client-python based on the client-version input if [ "${{ inputs.client-version }}" = "latest" ]; then echo "Installing latest llama-stack-client-python from main branch" - uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main + uv pip install git+https://github.com/llamastack/llama-stack-client-python.git@main elif [ "${{ inputs.client-version }}" = "published" ]; then echo "Installing published llama-stack-client-python from PyPI" uv pip install llama-stack-client diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index a38d4971a..9ef49fba3 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -52,7 +52,8 @@ jobs: run: | # Get test directories dynamically, excluding non-test directories # NOTE: we are excluding post_training since the tests take too long - TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" | + TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d | + sed 's|tests/integration/||' | grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" | sort | jq -R -s -c 'split("\n")[:-1]') echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index aa239572b..99a44c147 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -14,9 +14,11 @@ on: - 'pyproject.toml' - 'requirements.txt' - '.github/workflows/integration-vector-io-tests.yml' # This workflow + schedule: + - cron: '0 0 * * *' # (test on python 3.13) Daily at 12 AM UTC concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -25,7 +27,7 @@ jobs: strategy: matrix: vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"] - python-version: ["3.12", "3.13"] + python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} fail-fast: false # we want to run all tests regardless of failure steps: @@ -164,9 +166,9 @@ jobs: ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }} WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }} run: | - uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ + uv run pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ tests/integration/vector_io \ - --embedding-model sentence-transformers/all-MiniLM-L6-v2 + --embedding-model inline::sentence-transformers/all-MiniLM-L6-v2 - name: Check Storage and Memory Available After Tests if: ${{ always() }} diff --git a/.github/workflows/record-integration-tests.yml b/.github/workflows/record-integration-tests.yml index 12957db27..b31709a4f 100644 --- a/.github/workflows/record-integration-tests.yml +++ 
b/.github/workflows/record-integration-tests.yml @@ -3,7 +3,7 @@ name: Integration Tests (Record) run-name: Run the integration test suite from tests/integration on: - pull_request: + pull_request_target: branches: [ main ] types: [opened, synchronize, labeled] paths: @@ -23,7 +23,7 @@ on: default: 'ollama' concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number }} cancel-in-progress: true jobs: diff --git a/.github/workflows/semantic-pr.yml b/.github/workflows/semantic-pr.yml index 4df7324c4..57a4df646 100644 --- a/.github/workflows/semantic-pr.yml +++ b/.github/workflows/semantic-pr.yml @@ -11,7 +11,7 @@ on: - synchronize concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number }} cancel-in-progress: true permissions: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 30843173c..4309f289a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,6 +2,7 @@ exclude: 'build/' default_language_version: python: python3.12 + node: "22" repos: - repo: https://github.com/pre-commit/pre-commit-hooks @@ -145,6 +146,20 @@ repos: pass_filenames: false require_serial: true files: ^.github/workflows/.*$ + - id: ui-prettier + name: Format UI code with Prettier + entry: bash -c 'cd llama_stack/ui && npm run format' + language: system + files: ^llama_stack/ui/.*\.(ts|tsx)$ + pass_filenames: false + require_serial: true + - id: ui-eslint + name: Lint UI code with ESLint + entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet' + language: system + files: ^llama_stack/ui/.*\.(ts|tsx)$ + pass_filenames: false + require_serial: true ci: autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 066fcecf0..c81e9e7b1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,13 +1,82 @@ -# Contributing to Llama-Stack +# Contributing to Llama Stack We want to make contributing to this project as easy and transparent as possible. +## Set up your development environment + +We use [uv](https://github.com/astral-sh/uv) to manage python dependencies and virtual environments. +You can install `uv` by following this [guide](https://docs.astral.sh/uv/getting-started/installation/). + +You can install the dependencies by running: + +```bash +cd llama-stack +uv sync --group dev +uv pip install -e . +source .venv/bin/activate +``` + +```{note} +You can use a specific version of Python with `uv` by adding the `--python ` flag (e.g. `--python 3.12`). +Otherwise, `uv` will automatically select a Python version according to the `requires-python` section of the `pyproject.toml`. +For more info, see the [uv docs around Python versions](https://docs.astral.sh/uv/concepts/python-versions/). +``` + +Note that you can create a dotenv file `.env` that includes necessary environment variables: +``` +LLAMA_STACK_BASE_URL=http://localhost:8321 +LLAMA_STACK_CLIENT_LOG=debug +LLAMA_STACK_PORT=8321 +LLAMA_STACK_CONFIG= +TAVILY_SEARCH_API_KEY= +BRAVE_SEARCH_API_KEY= +``` + +And then use this dotenv file when running client SDK tests via the following: +```bash +uv run --env-file .env -- pytest -v tests/integration/inference/test_text_inference.py --text-model=meta-llama/Llama-3.1-8B-Instruct +``` + +### Pre-commit Hooks + +We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. 
You can install the pre-commit hooks by running: + +```bash +uv run pre-commit install +``` + +After that, pre-commit hooks will run automatically before each commit. + +Alternatively, if you don't want to install the pre-commit hooks, you can run the checks manually by running: + +```bash +uv run pre-commit run --all-files +``` + +```{caution} +Before pushing your changes, make sure that the pre-commit hooks have passed successfully. +``` + ## Discussions -> Issues -> Pull Requests We actively welcome your pull requests. However, please read the following. This is heavily inspired by [Ghostty](https://github.com/ghostty-org/ghostty/blob/main/CONTRIBUTING.md). If in doubt, please open a [discussion](https://github.com/meta-llama/llama-stack/discussions); we can always convert that to an issue later. +### Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](http://facebook.com/whitehat/info) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +### Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + **I'd like to contribute!** If you are new to the project, start by looking at the issues tagged with "good first issue". If you're interested @@ -51,93 +120,15 @@ Please avoid picking up too many issues at once. This helps you stay focused and Please keep pull requests (PRs) small and focused. If you have a large set of changes, consider splitting them into logically grouped, smaller PRs to facilitate review and testing. -> [!TIP] -> As a general guideline: -> - Experienced contributors should try to keep no more than 5 open PRs at a time. -> - New contributors are encouraged to have only one open PR at a time until they’re familiar with the codebase and process. - -## Contributor License Agreement ("CLA") -In order to accept your pull request, we need you to submit a CLA. You only need -to do this once to work on any of Meta's open source projects. - -Complete your CLA here: - -## Issues -We use GitHub issues to track public bugs. Please ensure your description is -clear and has sufficient instructions to be able to reproduce the issue. - -Meta has a [bounty program](http://facebook.com/whitehat/info) for the safe -disclosure of security bugs. In those cases, please go through the process -outlined on that page and do not file a public issue. - - -## Set up your development environment - -We use [uv](https://github.com/astral-sh/uv) to manage python dependencies and virtual environments. -You can install `uv` by following this [guide](https://docs.astral.sh/uv/getting-started/installation/). - -You can install the dependencies by running: - -```bash -cd llama-stack -uv sync --group dev -uv pip install -e . -source .venv/bin/activate +```{tip} +As a general guideline: +- Experienced contributors should try to keep no more than 5 open PRs at a time. +- New contributors are encouraged to have only one open PR at a time until they’re familiar with the codebase and process. ``` -> [!NOTE] -> You can use a specific version of Python with `uv` by adding the `--python ` flag (e.g. 
`--python 3.12`) -> Otherwise, `uv` will automatically select a Python version according to the `requires-python` section of the `pyproject.toml`. -> For more info, see the [uv docs around Python versions](https://docs.astral.sh/uv/concepts/python-versions/). +## Repository guidelines -Note that you can create a dotenv file `.env` that includes necessary environment variables: -``` -LLAMA_STACK_BASE_URL=http://localhost:8321 -LLAMA_STACK_CLIENT_LOG=debug -LLAMA_STACK_PORT=8321 -LLAMA_STACK_CONFIG= -TAVILY_SEARCH_API_KEY= -BRAVE_SEARCH_API_KEY= -``` - -And then use this dotenv file when running client SDK tests via the following: -```bash -uv run --env-file .env -- pytest -v tests/integration/inference/test_text_inference.py --text-model=meta-llama/Llama-3.1-8B-Instruct -``` - -## Pre-commit Hooks - -We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running: - -```bash -uv run pre-commit install -``` - -After that, pre-commit hooks will run automatically before each commit. - -Alternatively, if you don't want to install the pre-commit hooks, you can run the checks manually by running: - -```bash -uv run pre-commit run --all-files -``` - -> [!CAUTION] -> Before pushing your changes, make sure that the pre-commit hooks have passed successfully. - -## Running tests - -You can find the Llama Stack testing documentation [here](https://github.com/meta-llama/llama-stack/blob/main/tests/README.md). - -## Adding a new dependency to the project - -To add a new dependency to the project, you can use the `uv` command. For example, to add `foo` to the project, you can run: - -```bash -uv add foo -uv sync -``` - -## Coding Style +### Coding Style * Comments should provide meaningful insights into the code. Avoid filler comments that simply describe the next step, as they create unnecessary clutter, same goes for docstrings. @@ -159,6 +150,10 @@ uv sync * When possible, use keyword arguments only when calling functions. * Llama Stack utilizes [custom Exception classes](llama_stack/apis/common/errors.py) for certain Resources that should be used where applicable. +### License +By contributing to Llama, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. + ## Common Tasks Some tips about common tasks you work on while contributing to Llama Stack: @@ -210,8 +205,4 @@ If you modify or add new API endpoints, update the API documentation accordingly uv run ./docs/openapi_generator/run_openapi_generator.sh ``` -The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing. - -## License -By contributing to Llama, you agree that your contributions will be licensed -under the LICENSE file in the root directory of this source tree. +The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing. 
\ No newline at end of file diff --git a/README.md b/README.md index 03aa3dd50..4df4a5372 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack) + ### βœ¨πŸŽ‰ Llama 4 Support πŸŽ‰βœ¨ We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta. @@ -179,3 +180,17 @@ Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest Check out our client SDKs for connecting to a Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [typescript](https://github.com/meta-llama/llama-stack-client-typescript), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications. You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repo. + + +## 🌟 GitHub Star History +## Star History + +[![Star History Chart](https://api.star-history.com/svg?repos=meta-llama/llama-stack&type=Date)](https://www.star-history.com/#meta-llama/llama-stack&Date) + +## ✨ Contributors + +Thanks to all of our amazing contributors! + + + + \ No newline at end of file diff --git a/docs/_static/js/keyboard_shortcuts.js b/docs/_static/js/keyboard_shortcuts.js new file mode 100644 index 000000000..81d0b7c65 --- /dev/null +++ b/docs/_static/js/keyboard_shortcuts.js @@ -0,0 +1,14 @@ +document.addEventListener('keydown', function(event) { + // command+K or ctrl+K + if ((event.metaKey || event.ctrlKey) && event.key === 'k') { + event.preventDefault(); + document.querySelector('.search-input, .search-field, input[name="q"]').focus(); + } + + // forward slash + if (event.key === '/' && + !event.target.matches('input, textarea, select')) { + event.preventDefault(); + document.querySelector('.search-input, .search-field, input[name="q"]').focus(); + } +}); diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index d480ff592..0549dda21 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -8293,28 +8293,60 @@ "type": "array", "items": { "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" + "properties": { + "attributes": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } + "description": "(Optional) Key-value attributes associated with the file" + }, + "file_id": { + "type": "string", + "description": "Unique identifier of the file containing the result" + }, + "filename": { + "type": "string", + "description": "Name of the file containing the result" + }, + "score": { + "type": "number", + "description": "Relevance score for this search result (between 0 and 1)" + }, + "text": { + "type": "string", + "description": "Text 
content of the search result" + } + }, + "additionalProperties": false, + "required": [ + "attributes", + "file_id", + "filename", + "score", + "text" + ], + "title": "OpenAIResponseOutputMessageFileSearchToolCallResults", + "description": "Search results returned by the file search operation." }, "description": "(Optional) Search results returned by the file search operation" } @@ -8515,6 +8547,13 @@ "$ref": "#/components/schemas/OpenAIResponseInputTool" } }, + "include": { + "type": "array", + "items": { + "type": "string" + }, + "description": "(Optional) Additional fields to include in the response." + }, "max_infer_iters": { "type": "integer" } @@ -8782,6 +8821,61 @@ "title": "OpenAIResponseOutputMessageMCPListTools", "description": "MCP list tools output message containing available tools from an MCP server." }, + "OpenAIResponseContentPart": { + "oneOf": [ + { + "$ref": "#/components/schemas/OpenAIResponseContentPartOutputText" + }, + { + "$ref": "#/components/schemas/OpenAIResponseContentPartRefusal" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "output_text": "#/components/schemas/OpenAIResponseContentPartOutputText", + "refusal": "#/components/schemas/OpenAIResponseContentPartRefusal" + } + } + }, + "OpenAIResponseContentPartOutputText": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "output_text", + "default": "output_text" + }, + "text": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "text" + ], + "title": "OpenAIResponseContentPartOutputText" + }, + "OpenAIResponseContentPartRefusal": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "refusal", + "default": "refusal" + }, + "refusal": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "refusal" + ], + "title": "OpenAIResponseContentPartRefusal" + }, "OpenAIResponseObjectStream": { "oneOf": [ { @@ -8838,6 +8932,12 @@ { "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted" }, + { + "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded" + }, + { + "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone" + }, { "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" } @@ -8863,6 +8963,8 @@ "response.mcp_call.in_progress": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress", "response.mcp_call.failed": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed", "response.mcp_call.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted", + "response.content_part.added": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded", + "response.content_part.done": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone", "response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" } } @@ -8889,6 +8991,80 @@ "title": "OpenAIResponseObjectStreamResponseCompleted", "description": "Streaming event indicating a response has been completed." 
}, + "OpenAIResponseObjectStreamResponseContentPartAdded": { + "type": "object", + "properties": { + "response_id": { + "type": "string", + "description": "Unique identifier of the response containing this content" + }, + "item_id": { + "type": "string", + "description": "Unique identifier of the output item containing this content part" + }, + "part": { + "$ref": "#/components/schemas/OpenAIResponseContentPart", + "description": "The content part that was added" + }, + "sequence_number": { + "type": "integer", + "description": "Sequential number for ordering streaming events" + }, + "type": { + "type": "string", + "const": "response.content_part.added", + "default": "response.content_part.added", + "description": "Event type identifier, always \"response.content_part.added\"" + } + }, + "additionalProperties": false, + "required": [ + "response_id", + "item_id", + "part", + "sequence_number", + "type" + ], + "title": "OpenAIResponseObjectStreamResponseContentPartAdded", + "description": "Streaming event for when a new content part is added to a response item." + }, + "OpenAIResponseObjectStreamResponseContentPartDone": { + "type": "object", + "properties": { + "response_id": { + "type": "string", + "description": "Unique identifier of the response containing this content" + }, + "item_id": { + "type": "string", + "description": "Unique identifier of the output item containing this content part" + }, + "part": { + "$ref": "#/components/schemas/OpenAIResponseContentPart", + "description": "The completed content part" + }, + "sequence_number": { + "type": "integer", + "description": "Sequential number for ordering streaming events" + }, + "type": { + "type": "string", + "const": "response.content_part.done", + "default": "response.content_part.done", + "description": "Event type identifier, always \"response.content_part.done\"" + } + }, + "additionalProperties": false, + "required": [ + "response_id", + "item_id", + "part", + "sequence_number", + "type" + ], + "title": "OpenAIResponseObjectStreamResponseContentPartDone", + "description": "Streaming event for when a content part is completed." + }, "OpenAIResponseObjectStreamResponseCreated": { "type": "object", "properties": { @@ -16530,7 +16706,7 @@ "additionalProperties": { "type": "number" }, - "description": "A list of the categories along with their scores as predicted by model. Required set of categories that need to be in response - violence - violence/graphic - harassment - harassment/threatening - hate - hate/threatening - illicit - illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent - self-harm/instructions" + "description": "A list of the categories along with their scores as predicted by model." 
}, "user_message": { "type": "string" diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 9c0fba554..aa47cd58d 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6021,14 +6021,44 @@ components: type: array items: type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object + properties: + attributes: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + (Optional) Key-value attributes associated with the file + file_id: + type: string + description: >- + Unique identifier of the file containing the result + filename: + type: string + description: Name of the file containing the result + score: + type: number + description: >- + Relevance score for this search result (between 0 and 1) + text: + type: string + description: Text content of the search result + additionalProperties: false + required: + - attributes + - file_id + - filename + - score + - text + title: >- + OpenAIResponseOutputMessageFileSearchToolCallResults + description: >- + Search results returned by the file search operation. description: >- (Optional) Search results returned by the file search operation additionalProperties: false @@ -6188,6 +6218,12 @@ components: type: array items: $ref: '#/components/schemas/OpenAIResponseInputTool' + include: + type: array + items: + type: string + description: >- + (Optional) Additional fields to include in the response. max_infer_iters: type: integer additionalProperties: false @@ -6405,6 +6441,43 @@ components: title: OpenAIResponseOutputMessageMCPListTools description: >- MCP list tools output message containing available tools from an MCP server. 
+ OpenAIResponseContentPart: + oneOf: + - $ref: '#/components/schemas/OpenAIResponseContentPartOutputText' + - $ref: '#/components/schemas/OpenAIResponseContentPartRefusal' + discriminator: + propertyName: type + mapping: + output_text: '#/components/schemas/OpenAIResponseContentPartOutputText' + refusal: '#/components/schemas/OpenAIResponseContentPartRefusal' + OpenAIResponseContentPartOutputText: + type: object + properties: + type: + type: string + const: output_text + default: output_text + text: + type: string + additionalProperties: false + required: + - type + - text + title: OpenAIResponseContentPartOutputText + OpenAIResponseContentPartRefusal: + type: object + properties: + type: + type: string + const: refusal + default: refusal + refusal: + type: string + additionalProperties: false + required: + - type + - refusal + title: OpenAIResponseContentPartRefusal OpenAIResponseObjectStream: oneOf: - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated' @@ -6425,6 +6498,8 @@ components: - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted' + - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded' + - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted' discriminator: propertyName: type @@ -6447,6 +6522,8 @@ components: response.mcp_call.in_progress: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress' response.mcp_call.failed: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed' response.mcp_call.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted' + response.content_part.added: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded' + response.content_part.done: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone' response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted' "OpenAIResponseObjectStreamResponseCompleted": type: object @@ -6468,6 +6545,76 @@ components: OpenAIResponseObjectStreamResponseCompleted description: >- Streaming event indicating a response has been completed. + "OpenAIResponseObjectStreamResponseContentPartAdded": + type: object + properties: + response_id: + type: string + description: >- + Unique identifier of the response containing this content + item_id: + type: string + description: >- + Unique identifier of the output item containing this content part + part: + $ref: '#/components/schemas/OpenAIResponseContentPart' + description: The content part that was added + sequence_number: + type: integer + description: >- + Sequential number for ordering streaming events + type: + type: string + const: response.content_part.added + default: response.content_part.added + description: >- + Event type identifier, always "response.content_part.added" + additionalProperties: false + required: + - response_id + - item_id + - part + - sequence_number + - type + title: >- + OpenAIResponseObjectStreamResponseContentPartAdded + description: >- + Streaming event for when a new content part is added to a response item. 
+ "OpenAIResponseObjectStreamResponseContentPartDone": + type: object + properties: + response_id: + type: string + description: >- + Unique identifier of the response containing this content + item_id: + type: string + description: >- + Unique identifier of the output item containing this content part + part: + $ref: '#/components/schemas/OpenAIResponseContentPart' + description: The completed content part + sequence_number: + type: integer + description: >- + Sequential number for ordering streaming events + type: + type: string + const: response.content_part.done + default: response.content_part.done + description: >- + Event type identifier, always "response.content_part.done" + additionalProperties: false + required: + - response_id + - item_id + - part + - sequence_number + - type + title: >- + OpenAIResponseObjectStreamResponseContentPartDone + description: >- + Streaming event for when a content part is completed. "OpenAIResponseObjectStreamResponseCreated": type: object properties: @@ -12286,10 +12433,6 @@ components: type: number description: >- A list of the categories along with their scores as predicted by model. - Required set of categories that need to be in response - violence - violence/graphic - - harassment - harassment/threatening - hate - hate/threatening - illicit - - illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent - - self-harm/instructions user_message: type: string metadata: diff --git a/docs/source/apis/external.md b/docs/source/apis/external.md index cc13deb9b..5831990b0 100644 --- a/docs/source/apis/external.md +++ b/docs/source/apis/external.md @@ -111,7 +111,7 @@ name = "llama-stack-api-weather" version = "0.1.0" description = "Weather API for Llama Stack" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.12" dependencies = ["llama-stack", "pydantic"] [build-system] @@ -231,7 +231,7 @@ name = "llama-stack-provider-kaze" version = "0.1.0" description = "Kaze weather provider for Llama Stack" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.12" dependencies = ["llama-stack", "pydantic", "aiohttp"] [build-system] diff --git a/docs/source/building_applications/responses_vs_agents.md b/docs/source/building_applications/responses_vs_agents.md index 3eebfb460..5abe951d6 100644 --- a/docs/source/building_applications/responses_vs_agents.md +++ b/docs/source/building_applications/responses_vs_agents.md @@ -2,7 +2,9 @@ Llama Stack (LLS) provides two different APIs for building AI applications with tool calling capabilities: the **Agents API** and the **OpenAI Responses API**. While both enable AI systems to use tools, and maintain full conversation history, they serve different use cases and have distinct characteristics. -> **Note:** For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to Agents or Responses API. +```{note} +For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to Agents or Responses API. 
+``` ## Overview diff --git a/docs/source/building_applications/tools.md b/docs/source/building_applications/tools.md index b19be888c..8a54290ed 100644 --- a/docs/source/building_applications/tools.md +++ b/docs/source/building_applications/tools.md @@ -76,7 +76,9 @@ Features: - Context retrieval with token limits -> **Note:** By default, llama stack run.yaml defines toolgroups for web search, wolfram alpha and rag, that are provided by tavily-search, wolfram-alpha and rag providers. +```{note} +By default, llama stack run.yaml defines toolgroups for web search, wolfram alpha and rag, that are provided by tavily-search, wolfram-alpha and rag providers. +``` ## Model Context Protocol (MCP) diff --git a/docs/source/conf.py b/docs/source/conf.py index 20f1abf00..3f84d1310 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -131,6 +131,7 @@ html_static_path = ["../_static"] def setup(app): app.add_css_file("css/my_theme.css") app.add_js_file("js/detect_theme.js") + app.add_js_file("js/keyboard_shortcuts.js") def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]): url = f"https://hub.docker.com/r/llamastack/{text}" diff --git a/docs/source/contributing/index.md b/docs/source/contributing/index.md index 1e067ea6c..296a49f24 100644 --- a/docs/source/contributing/index.md +++ b/docs/source/contributing/index.md @@ -2,14 +2,33 @@ ```{include} ../../../CONTRIBUTING.md ``` -See the [Adding a New API Provider](new_api_provider.md) which describes how to add new API providers to the Stack. - +## Adding a New Provider +See: +- [Adding a New API Provider Page](new_api_provider.md) which describes how to add new API providers to the Stack. +- [Vector Database Page](new_vector_database.md) which describes how to add a new vector databases with Llama Stack. +- [External Provider Page](../providers/external/index.md) which describes how to add external providers to the Stack. ```{toctree} :maxdepth: 1 :hidden: new_api_provider -testing +new_vector_database +``` + +## Testing + + +```{include} ../../../tests/README.md +``` + +### Advanced Topics + +For developers who need deeper understanding of the testing system internals: + +```{toctree} +:maxdepth: 1 + +testing/record-replay ``` diff --git a/docs/source/contributing/new_vector_database.md b/docs/source/contributing/new_vector_database.md new file mode 100644 index 000000000..83c0f55bc --- /dev/null +++ b/docs/source/contributing/new_vector_database.md @@ -0,0 +1,75 @@ +# Adding a New Vector Database + +This guide will walk you through the process of adding a new vector database to Llama Stack. + +> **_NOTE:_** Here's an example Pull Request of the [Milvus Vector Database Provider](https://github.com/meta-llama/llama-stack/pull/1467). + +Vector Database providers are used to store and retrieve vector embeddings. Vector databases are not limited to vector +search but can support keyword and hybrid search. Additionally, vector database can also support operations like +filtering, sorting, and aggregating vectors. + +## Steps to Add a New Vector Database Provider +1. **Choose the Database Type**: Determine if your vector database is a remote service, inline, or both. + - Remote databases make requests to external services, while inline databases execute locally. Some providers support both. +2. **Implement the Provider**: Create a new provider class that inherits from `VectorDatabaseProvider` and implements the required methods. 
+ - Implement methods for vector storage, retrieval, search, and any additional features your database supports. + - You will need to implement the following methods for `YourVectorIndex`: + - `YourVectorIndex.create()` + - `YourVectorIndex.initialize()` + - `YourVectorIndex.add_chunks()` + - `YourVectorIndex.delete_chunk()` + - `YourVectorIndex.query_vector()` + - `YourVectorIndex.query_keyword()` + - `YourVectorIndex.query_hybrid()` + - You will need to implement the following methods for `YourVectorIOAdapter`: + - `YourVectorIOAdapter.initialize()` + - `YourVectorIOAdapter.shutdown()` + - `YourVectorIOAdapter.list_vector_dbs()` + - `YourVectorIOAdapter.register_vector_db()` + - `YourVectorIOAdapter.unregister_vector_db()` + - `YourVectorIOAdapter.insert_chunks()` + - `YourVectorIOAdapter.query_chunks()` + - `YourVectorIOAdapter.delete_chunks()` +3. **Add to Registry**: Register your provider in the appropriate registry file. + - Update {repopath}`llama_stack/providers/registry/vector_io.py` to include your new provider. +```python +from llama_stack.providers.registry.specs import InlineProviderSpec +from llama_stack.providers.registry.api import Api + +InlineProviderSpec( + api=Api.vector_io, + provider_type="inline::milvus", + pip_packages=["pymilvus>=2.4.10"], + module="llama_stack.providers.inline.vector_io.milvus", + config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig", + api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], + description="", +), +``` +4. **Add Tests**: Create unit tests and integration tests for your provider in the `tests/` directory. + - Unit Tests + - By following the structure of the class methods, you will be able to easily run unit and integration tests for your database. + 1. You have to configure the tests for your provider in `/tests/unit/providers/vector_io/conftest.py`. + 2. Update the `vector_provider` fixture to include your provider if it is an inline provider. + 3. Create a `your_vectorprovider_index` fixture that initializes your vector index. + 4. Create a `your_vectorprovider_adapter` fixture that initializes your vector adapter. + 5. Add your provider to the `vector_io_providers` fixture dictionary. + - Please follow the naming convention of `your_vectorprovider_index` and `your_vectorprovider_adapter` as the tests require this to execute properly. + - Integration Tests + - Integration tests are located in {repopath}`tests/integration`. These tests use the Python client SDK APIs (from the `llama_stack_client` package) to test functionality. + - The two sets of integration tests are: + - `tests/integration/vector_io/test_vector_io.py`: This file tests registration, insertion, and retrieval. + - `tests/integration/vector_io/test_openai_vector_stores.py`: These tests are for OpenAI-compatible vector stores and test the OpenAI API compatibility. + - You will need to update `skip_if_provider_doesnt_support_openai_vector_stores` to include your provider as well as `skip_if_provider_doesnt_support_openai_vector_stores_search` to test the appropriate search functionality. + - Running the tests in the GitHub CI + - You will need to update the `.github/workflows/integration-vector-io-tests.yml` file to include your provider. + - If your provider is a remote provider, you will also have to add a container to spin up and run it in the action. + - Updating the pyproject.toml + - If you are adding tests for the `inline` provider you will have to update the `unit` group.
+ - `uv add new_pip_package --group unit` + - If you are adding tests for the `remote` provider you will have to update the `test` group, which is used in the GitHub CI for integration tests. + - `uv add new_pip_package --group test` +5. **Update Documentation**: Please update the documentation for end users + - Generate the provider documentation by running {repopath}`./scripts/provider_codegen.py`. + - Update the autogenerated content in the registry/vector_io.py file with information about your provider. Please see other providers for examples. \ No newline at end of file diff --git a/docs/source/contributing/testing.md b/docs/source/contributing/testing.md deleted file mode 100644 index 47bf9dea7..000000000 --- a/docs/source/contributing/testing.md +++ /dev/null @@ -1,6 +0,0 @@ -# Testing Llama Stack - -Tests are of three different kinds: -- Unit tests -- Provider focused integration tests -- Client SDK tests diff --git a/docs/source/contributing/testing/record-replay.md b/docs/source/contributing/testing/record-replay.md new file mode 100644 index 000000000..3049d333c --- /dev/null +++ b/docs/source/contributing/testing/record-replay.md @@ -0,0 +1,234 @@ +# Record-Replay System + +Understanding how Llama Stack captures and replays API interactions for testing. + +## Overview + +The record-replay system solves a fundamental challenge in AI testing: how do you test against expensive, non-deterministic APIs without breaking the bank or dealing with flaky tests? + +The solution: intercept API calls, store real responses, and replay them later. This gives you real API behavior without the cost or variability. + +## How It Works + +### Request Hashing + +Every API request gets converted to a deterministic hash for lookup: + +```python +def normalize_request(method: str, url: str, headers: dict, body: dict) -> str: + normalized = { + "method": method.upper(), + "endpoint": urlparse(url).path, # Just the path, not full URL + "body": body, # Request parameters + } + return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest() +``` + +**Key insight:** The hashing is intentionally precise. Different whitespace, float precision, or parameter order produces different hashes. This prevents subtle bugs from false cache hits. + +```python +# These produce DIFFERENT hashes: +{"content": "Hello world"} +{"content": "Hello world\n"} +{"temperature": 0.7} +{"temperature": 0.7000001} +``` + +### Client Interception + +The system patches OpenAI and Ollama client methods to intercept calls before they leave your application. This happens transparently - your test code doesn't change. + +### Storage Architecture + +Recordings use a two-tier storage system optimized for both speed and debuggability: + +``` +recordings/ +β”œβ”€β”€ index.sqlite # Fast lookup by request hash +└── responses/ + β”œβ”€β”€ abc123def456.json # Individual response files + └── def789ghi012.json +``` + +**SQLite index** enables O(log n) hash lookups and metadata queries without loading response bodies. + +**JSON files** store complete request/response pairs in human-readable format for debugging. + +## Recording Modes + +### LIVE Mode + +Direct API calls with no recording or replay: + +```python +with inference_recording(mode=InferenceMode.LIVE): + response = await client.chat.completions.create(...) +``` + +Use for initial development and debugging against real APIs. 
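To make the hash-precision point above concrete, here is a small self-contained check. It reuses the `normalize_request` logic shown in the Request Hashing section; the URL is only an illustrative value, not a required endpoint.

```python
import hashlib
import json
from urllib.parse import urlparse


def normalize_request(method: str, url: str, headers: dict, body: dict) -> str:
    # Same normalization as shown in the Request Hashing section above.
    normalized = {
        "method": method.upper(),
        "endpoint": urlparse(url).path,
        "body": body,
    }
    return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()


# Two nearly identical requests hash differently, so REPLAY will never
# silently serve one recording in place of the other.
url = "http://localhost:8321/v1/chat/completions"  # illustrative URL only
a = normalize_request("POST", url, {}, {"temperature": 0.7})
b = normalize_request("POST", url, {}, {"temperature": 0.7000001})
assert a != b
```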
+ +### RECORD Mode + +Captures API interactions while passing through real responses: + +```python +with inference_recording(mode=InferenceMode.RECORD, storage_dir="./recordings"): + response = await client.chat.completions.create(...) + # Real API call made, response captured AND returned +``` + +The recording process: +1. Request intercepted and hashed +2. Real API call executed +3. Response captured and serialized +4. Recording stored to disk +5. Original response returned to caller + +### REPLAY Mode + +Returns stored responses instead of making API calls: + +```python +with inference_recording(mode=InferenceMode.REPLAY, storage_dir="./recordings"): + response = await client.chat.completions.create(...) + # No API call made, cached response returned instantly +``` + +The replay process: +1. Request intercepted and hashed +2. Hash looked up in SQLite index +3. Response loaded from JSON file +4. Response deserialized and returned +5. Error if no recording found + +## Streaming Support + +Streaming APIs present a unique challenge: how do you capture an async generator? + +### The Problem + +```python +# How do you record this? +async for chunk in client.chat.completions.create(stream=True): + process(chunk) +``` + +### The Solution + +The system captures all chunks immediately before yielding any: + +```python +async def handle_streaming_record(response): + # Capture complete stream first + chunks = [] + async for chunk in response: + chunks.append(chunk) + + # Store complete recording + storage.store_recording( + request_hash, request_data, {"body": chunks, "is_streaming": True} + ) + + # Return generator that replays captured chunks + async def replay_stream(): + for chunk in chunks: + yield chunk + + return replay_stream() +``` + +This ensures: +- **Complete capture** - The entire stream is saved atomically +- **Interface preservation** - The returned object behaves like the original API +- **Deterministic replay** - Same chunks in the same order every time + +## Serialization + +API responses contain complex Pydantic objects that need careful serialization: + +```python +def _serialize_response(response): + if hasattr(response, "model_dump"): + # Preserve type information for proper deserialization + return { + "__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}", + "__data__": response.model_dump(mode="json"), + } + return response +``` + +This preserves type safety - when replayed, you get the same Pydantic objects with all their validation and methods. + +## Environment Integration + +### Environment Variables + +Control recording behavior globally: + +```bash +export LLAMA_STACK_TEST_INFERENCE_MODE=replay +export LLAMA_STACK_TEST_RECORDING_DIR=/path/to/recordings +pytest tests/integration/ +``` + +### Pytest Integration + +The system integrates automatically based on environment variables, requiring no changes to test code. 
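The exact hook lives inside the test harness, but conceptually the integration boils down to reading these variables and entering the matching recording context. A minimal sketch, assuming the `inference_recording` context manager and `InferenceMode` enum from the examples above are importable and that `InferenceMode` accepts the lowercase mode names; the import path below is hypothetical.

```python
# conftest.py -- illustrative sketch, not the actual Llama Stack fixture
import os

import pytest

# Hypothetical import path; the real helpers live in the project's test support code.
from llama_stack.testing.inference_recorder import InferenceMode, inference_recording


@pytest.fixture(autouse=True)
def recorded_inference():
    # Assumes InferenceMode is a string enum with values "live", "record", "replay".
    mode = InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live"))
    storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", "./recordings")
    with inference_recording(mode=mode, storage_dir=storage_dir):
        yield
```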
+ +## Debugging Recordings + +### Inspecting Storage + +```bash +# See what's recorded +sqlite3 recordings/index.sqlite "SELECT endpoint, model, timestamp FROM recordings LIMIT 10;" + +# View specific response +cat recordings/responses/abc123def456.json | jq '.response.body' + +# Find recordings by endpoint +sqlite3 recordings/index.sqlite "SELECT * FROM recordings WHERE endpoint='/v1/chat/completions';" +``` + +### Common Issues + +**Hash mismatches:** Request parameters changed slightly between record and replay +```bash +# Compare request details +cat recordings/responses/abc123.json | jq '.request' +``` + +**Serialization errors:** Response types changed between versions +```bash +# Re-record with updated types +rm recordings/responses/failing_hash.json +LLAMA_STACK_TEST_INFERENCE_MODE=record pytest test_failing.py +``` + +**Missing recordings:** New test or changed parameters +```bash +# Record the missing interaction +LLAMA_STACK_TEST_INFERENCE_MODE=record pytest test_new.py +``` + +## Design Decisions + +### Why Not Mocks? + +Traditional mocking breaks down with AI APIs because: +- Response structures are complex and evolve frequently +- Streaming behavior is hard to mock correctly +- Edge cases in real APIs get missed +- Mocks become brittle maintenance burdens + +### Why Precise Hashing? + +Loose hashing (normalizing whitespace, rounding floats) seems convenient but hides bugs. If a test changes slightly, you want to know about it rather than accidentally getting the wrong cached response. + +### Why JSON + SQLite? + +- **JSON** - Human readable, diff-friendly, easy to inspect and modify +- **SQLite** - Fast indexed lookups without loading response bodies +- **Hybrid** - Best of both worlds for different use cases + +This system provides reliable, fast testing against real AI APIs while maintaining the ability to debug issues when they arise. \ No newline at end of file diff --git a/docs/source/distributions/k8s-benchmark/apply.sh b/docs/source/distributions/k8s-benchmark/apply.sh new file mode 100755 index 000000000..119a1c849 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/apply.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Deploys the benchmark-specific components on top of the base k8s deployment (../k8s/apply.sh). 
+ +export MOCK_INFERENCE_PORT=8080 +export STREAM_DELAY_SECONDS=0.005 + +export POSTGRES_USER=llamastack +export POSTGRES_DB=llamastack +export POSTGRES_PASSWORD=llamastack + +export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct +export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B + +export MOCK_INFERENCE_MODEL=mock-inference + +# Use llama-stack-benchmark-service as the benchmark server +export LOCUST_HOST=http://llama-stack-benchmark-service:8323 +export LOCUST_BASE_PATH=/v1/openai/v1 + +# Use vllm-service as the benchmark server +# export LOCUST_HOST=http://vllm-server:8000 +# export LOCUST_BASE_PATH=/v1 + + +export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL + +set -euo pipefail +set -x + +# Deploy benchmark-specific components +# Deploy OpenAI mock server +kubectl create configmap openai-mock --from-file=openai-mock-server.py \ + --dry-run=client -o yaml | kubectl apply --validate=false -f - + +envsubst < openai-mock-deployment.yaml | kubectl apply --validate=false -f - + +# Create configmap with our custom stack config +kubectl create configmap llama-stack-config --from-file=stack_run_config.yaml \ + --dry-run=client -o yaml > stack-configmap.yaml + +kubectl apply --validate=false -f stack-configmap.yaml + +# Deploy our custom llama stack server (overriding the base one) +envsubst < stack-k8s.yaml.template | kubectl apply --validate=false -f - + +# Deploy Locust load testing +kubectl create configmap locust-script --from-file=locustfile.py \ + --dry-run=client -o yaml | kubectl apply --validate=false -f - + +envsubst < locust-k8s.yaml | kubectl apply --validate=false -f - diff --git a/docs/source/distributions/k8s-benchmark/locust-k8s.yaml b/docs/source/distributions/k8s-benchmark/locust-k8s.yaml new file mode 100644 index 000000000..f20a01b2d --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/locust-k8s.yaml @@ -0,0 +1,131 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: locust-master + labels: + app: locust + role: master +spec: + replicas: 1 + selector: + matchLabels: + app: locust + role: master + template: + metadata: + labels: + app: locust + role: master + spec: + containers: + - name: locust-master + image: locustio/locust:2.31.8 + ports: + - containerPort: 8089 # Web UI + - containerPort: 5557 # Master communication + env: + - name: LOCUST_HOST + value: "${LOCUST_HOST}" + - name: LOCUST_LOCUSTFILE + value: "/locust/locustfile.py" + - name: LOCUST_WEB_HOST + value: "0.0.0.0" + - name: LOCUST_MASTER + value: "true" + - name: LOCUST_BASE_PATH + value: "${LOCUST_BASE_PATH}" + - name: INFERENCE_MODEL + value: "${BENCHMARK_INFERENCE_MODEL}" + volumeMounts: + - name: locust-script + mountPath: /locust + command: ["locust"] + args: + - "--master" + - "--web-host=0.0.0.0" + - "--web-port=8089" + - "--host=${LOCUST_HOST}" + - "--locustfile=/locust/locustfile.py" + volumes: + - name: locust-script + configMap: + name: locust-script +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: locust-worker + labels: + app: locust + role: worker +spec: + replicas: 2 # Start with 2 workers, can be scaled up + selector: + matchLabels: + app: locust + role: worker + template: + metadata: + labels: + app: locust + role: worker + spec: + containers: + - name: locust-worker + image: locustio/locust:2.31.8 + env: + - name: LOCUST_HOST + value: "${LOCUST_HOST}" + - name: LOCUST_LOCUSTFILE + value: "/locust/locustfile.py" + - name: LOCUST_MASTER_HOST + value: "locust-master-service" + - name: LOCUST_MASTER_PORT + value: "5557" + - name: INFERENCE_MODEL + value: 
"${BENCHMARK_INFERENCE_MODEL}" + - name: LOCUST_BASE_PATH + value: "${LOCUST_BASE_PATH}" + volumeMounts: + - name: locust-script + mountPath: /locust + command: ["locust"] + args: + - "--worker" + - "--master-host=locust-master-service" + - "--master-port=5557" + - "--locustfile=/locust/locustfile.py" + volumes: + - name: locust-script + configMap: + name: locust-script +--- +apiVersion: v1 +kind: Service +metadata: + name: locust-master-service +spec: + selector: + app: locust + role: master + ports: + - name: web-ui + port: 8089 + targetPort: 8089 + - name: master-comm + port: 5557 + targetPort: 5557 + type: ClusterIP +--- +apiVersion: v1 +kind: Service +metadata: + name: locust-web-ui +spec: + selector: + app: locust + role: master + ports: + - port: 8089 + targetPort: 8089 + type: ClusterIP # Keep internal, use port-forward to access diff --git a/docs/source/distributions/k8s-benchmark/locustfile.py b/docs/source/distributions/k8s-benchmark/locustfile.py new file mode 100644 index 000000000..8e511fa95 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/locustfile.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Locust load testing script for Llama Stack with Prism mock OpenAI provider. +""" + +import random +from locust import HttpUser, task, between +import os + +base_path = os.getenv("LOCUST_BASE_PATH", "/v1/openai/v1") + +MODEL_ID = os.getenv("INFERENCE_MODEL") + +class LlamaStackUser(HttpUser): + wait_time = between(0.0, 0.0001) + + def on_start(self): + """Setup authentication and test data.""" + # No auth required for benchmark server + self.headers = { + "Content-Type": "application/json" + } + + # Test messages of varying lengths + self.test_messages = [ + [{"role": "user", "content": "Hi"}], + [{"role": "user", "content": "What is the capital of France?"}], + [{"role": "user", "content": "Explain quantum physics in simple terms."}], + [{"role": "user", "content": "Write a short story about a robot learning to paint."}], + [ + {"role": "user", "content": "What is machine learning?"}, + {"role": "assistant", "content": "Machine learning is a subset of AI..."}, + {"role": "user", "content": "Can you give me a practical example?"} + ] + ] + + @task(weight=100) + def chat_completion_streaming(self): + """Test streaming chat completion (20% of requests).""" + messages = random.choice(self.test_messages) + payload = { + "model": MODEL_ID, + "messages": messages, + "stream": True, + "max_tokens": 100 + } + + with self.client.post( + f"{base_path}/chat/completions", + headers=self.headers, + json=payload, + stream=True, + catch_response=True + ) as response: + if response.status_code == 200: + chunks_received = 0 + try: + for line in response.iter_lines(): + if line: + line_str = line.decode('utf-8') + if line_str.startswith('data: '): + chunks_received += 1 + if line_str.strip() == 'data: [DONE]': + break + + if chunks_received > 0: + response.success() + else: + response.failure("No streaming chunks received") + except Exception as e: + response.failure(f"Streaming error: {e}") + else: + response.failure(f"HTTP {response.status_code}: {response.text}") diff --git a/docs/source/distributions/k8s-benchmark/openai-mock-deployment.yaml b/docs/source/distributions/k8s-benchmark/openai-mock-deployment.yaml new file mode 100644 index 000000000..c72921281 --- /dev/null +++ 
b/docs/source/distributions/k8s-benchmark/openai-mock-deployment.yaml @@ -0,0 +1,52 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: openai-mock + labels: + app: openai-mock +spec: + replicas: 1 + selector: + matchLabels: + app: openai-mock + template: + metadata: + labels: + app: openai-mock + spec: + containers: + - name: openai-mock + image: python:3.12-slim + ports: + - containerPort: ${MOCK_INFERENCE_PORT} + env: + - name: PORT + value: "${MOCK_INFERENCE_PORT}" + - name: MOCK_MODELS + value: "${MOCK_INFERENCE_MODEL}" + - name: STREAM_DELAY_SECONDS + value: "${STREAM_DELAY_SECONDS}" + command: ["sh", "-c"] + args: + - | + pip install flask && + python /app/openai-mock-server.py --port ${MOCK_INFERENCE_PORT} + volumeMounts: + - name: openai-mock-script + mountPath: /app + volumes: + - name: openai-mock-script + configMap: + name: openai-mock +--- +apiVersion: v1 +kind: Service +metadata: + name: openai-mock-service +spec: + selector: + app: openai-mock + ports: + - port: 8080 + targetPort: 8080 + type: ClusterIP diff --git a/docs/source/distributions/k8s-benchmark/openai-mock-server.py b/docs/source/distributions/k8s-benchmark/openai-mock-server.py new file mode 100755 index 000000000..46c923b60 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/openai-mock-server.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +OpenAI-compatible mock server that returns: +- Hardcoded /models response for consistent validation +- Valid OpenAI-formatted chat completion responses with dynamic content +""" + +from flask import Flask, request, jsonify, Response +import time +import random +import uuid +import json +import argparse +import os + +app = Flask(__name__) + +# Models from environment variables +def get_models(): + models_str = os.getenv("MOCK_MODELS", "mock-inference") + model_ids = [m.strip() for m in models_str.split(",") if m.strip()] + + return { + "object": "list", + "data": [ + { + "id": model_id, + "object": "model", + "created": 1234567890, + "owned_by": "vllm" + } + for model_id in model_ids + ] + } + +def generate_random_text(length=50): + """Generate random but coherent text for responses.""" + words = [ + "Hello", "there", "I'm", "an", "AI", "assistant", "ready", "to", "help", "you", + "with", "your", "questions", "and", "tasks", "today", "Let", "me","know", "what", + "you'd", "like", "to", "discuss", "or", "explore", "together", "I", "can", "assist", + "with", "various", "topics", "including", "coding", "writing", "analysis", "and", "more" + ] + return " ".join(random.choices(words, k=length)) + +@app.route('/models', methods=['GET']) +def list_models(): + models = get_models() + print(f"[MOCK] Returning models: {[m['id'] for m in models['data']]}") + return jsonify(models) + +@app.route('/chat/completions', methods=['POST']) +def chat_completions(): + """Return OpenAI-formatted chat completion responses.""" + data = request.get_json() + default_model = get_models()['data'][0]['id'] + model = data.get('model', default_model) + messages = data.get('messages', []) + stream = data.get('stream', False) + + print(f"[MOCK] Chat completion request - model: {model}, stream: {stream}") + + if stream: + return handle_streaming_completion(model, messages) + else: + return handle_non_streaming_completion(model, messages) + +def handle_non_streaming_completion(model, 
messages): + response_text = generate_random_text(random.randint(20, 80)) + + # Calculate realistic token counts + prompt_tokens = sum(len(str(msg.get('content', '')).split()) for msg in messages) + completion_tokens = len(response_text.split()) + + response = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": response_text + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens + } + } + + return jsonify(response) + +def handle_streaming_completion(model, messages): + def generate_stream(): + # Generate response text + full_response = generate_random_text(random.randint(30, 100)) + words = full_response.split() + + # Send initial chunk + initial_chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""} + } + ] + } + yield f"data: {json.dumps(initial_chunk)}\n\n" + + # Send word by word + for i, word in enumerate(words): + chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": f"{word} " if i < len(words) - 1 else word} + } + ] + } + yield f"data: {json.dumps(chunk)}\n\n" + # Configurable delay to simulate realistic streaming + stream_delay = float(os.getenv("STREAM_DELAY_SECONDS", "0.005")) + time.sleep(stream_delay) + + # Send final chunk + final_chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": ""}, + "finish_reason": "stop" + } + ] + } + yield f"data: {json.dumps(final_chunk)}\n\n" + yield "data: [DONE]\n\n" + + return Response( + generate_stream(), + mimetype='text/event-stream', + headers={ + 'Cache-Control': 'no-cache', + 'Connection': 'keep-alive', + 'Access-Control-Allow-Origin': '*', + } + ) + +@app.route('/health', methods=['GET']) +def health(): + return jsonify({"status": "healthy", "type": "openai-mock"}) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='OpenAI-compatible mock server') + parser.add_argument('--port', type=int, default=8081, + help='Port to run the server on (default: 8081)') + args = parser.parse_args() + + port = args.port + + models = get_models() + print("Starting OpenAI-compatible mock server...") + print(f"- /models endpoint with: {[m['id'] for m in models['data']]}") + print("- OpenAI-formatted chat/completion responses with dynamic content") + print("- Streaming support with valid SSE format") + print(f"- Listening on: http://0.0.0.0:{port}") + app.run(host='0.0.0.0', port=port, debug=False) diff --git a/docs/source/distributions/k8s-benchmark/stack-configmap.yaml b/docs/source/distributions/k8s-benchmark/stack-configmap.yaml new file mode 100644 index 000000000..653e66756 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/stack-configmap.yaml @@ -0,0 +1,143 @@ +apiVersion: v1 +data: + stack_run_config.yaml: | + version: '2' + image_name: kubernetes-benchmark-demo + apis: + - agents + - inference + - safety + - telemetry + - tool_runtime + - vector_io + providers: + inference: + - provider_id: 
vllm-inference + provider_type: remote::vllm + config: + url: ${env.VLLM_URL:=http://localhost:8000/v1} + max_tokens: ${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: vllm-safety + provider_type: remote::vllm + config: + url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1} + max_tokens: ${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: mock-vllm-inference + provider_type: remote::vllm + config: + url: http://openai-mock-service:${env.MOCK_INFERENCE_PORT} + max_tokens: 4096 + api_token: fake + tls_verify: false + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} + vector_io: + - provider_id: ${env.ENABLE_CHROMADB:+chromadb} + provider_type: remote::chromadb + config: + url: ${env.CHROMADB_URL:=} + kvstore: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + responses_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" + sinks: ${env.TELEMETRY_SINKS:=console} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:+} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:+} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} + metadata_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + table_name: llamastack_kvstore + inference_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + models: + - metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: sentence-transformers + model_type: embedding + - model_id: ${env.INFERENCE_MODEL} + provider_id: vllm-inference + model_type: llm + - model_id: ${env.SAFETY_MODEL} + provider_id: vllm-safety + model_type: llm + - model_id: ${env.MOCK_INFERENCE_MODEL} + provider_id: mock-vllm-inference + model_type: llm + shields: + - shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B} + vector_dbs: [] + datasets: [] + scoring_fns: [] + benchmarks: [] + tool_groups: + - 
toolgroup_id: builtin::websearch + provider_id: tavily-search + - toolgroup_id: builtin::rag + provider_id: rag-runtime + server: + port: 8323 +kind: ConfigMap +metadata: + creationTimestamp: null + name: llama-stack-config diff --git a/docs/source/distributions/k8s-benchmark/stack-k8s.yaml.template b/docs/source/distributions/k8s-benchmark/stack-k8s.yaml.template new file mode 100644 index 000000000..bc14d5124 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/stack-k8s.yaml.template @@ -0,0 +1,87 @@ +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: llama-benchmark-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 1Gi +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: llama-stack-benchmark-server +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/name: llama-stack-benchmark + app.kubernetes.io/component: server + template: + metadata: + labels: + app.kubernetes.io/name: llama-stack-benchmark + app.kubernetes.io/component: server + spec: + containers: + - name: llama-stack-benchmark + image: llamastack/distribution-starter:latest + imagePullPolicy: Always # since we have specified latest instead of a version + env: + - name: ENABLE_CHROMADB + value: "true" + - name: CHROMADB_URL + value: http://chromadb.default.svc.cluster.local:6000 + - name: POSTGRES_HOST + value: postgres-server.default.svc.cluster.local + - name: POSTGRES_PORT + value: "5432" + - name: INFERENCE_MODEL + value: "${INFERENCE_MODEL}" + - name: SAFETY_MODEL + value: "${SAFETY_MODEL}" + - name: TAVILY_SEARCH_API_KEY + value: "${TAVILY_SEARCH_API_KEY}" + - name: MOCK_INFERENCE_PORT + value: "${MOCK_INFERENCE_PORT}" + - name: VLLM_URL + value: http://vllm-server.default.svc.cluster.local:8000/v1 + - name: VLLM_MAX_TOKENS + value: "3072" + - name: VLLM_SAFETY_URL + value: http://vllm-server-safety.default.svc.cluster.local:8001/v1 + - name: VLLM_TLS_VERIFY + value: "false" + - name: MOCK_INFERENCE_MODEL + value: "${MOCK_INFERENCE_MODEL}" + command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"] + ports: + - containerPort: 8323 + volumeMounts: + - name: llama-storage + mountPath: /root/.llama + - name: llama-config + mountPath: /etc/config + volumes: + - name: llama-storage + persistentVolumeClaim: + claimName: llama-benchmark-pvc + - name: llama-config + configMap: + name: llama-stack-config +--- +apiVersion: v1 +kind: Service +metadata: + name: llama-stack-benchmark-service +spec: + selector: + app.kubernetes.io/name: llama-stack-benchmark + app.kubernetes.io/component: server + ports: + - name: http + port: 8323 + targetPort: 8323 + type: ClusterIP diff --git a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml new file mode 100644 index 000000000..ad56be047 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml @@ -0,0 +1,136 @@ +version: '2' +image_name: kubernetes-benchmark-demo +apis: +- agents +- inference +- safety +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: vllm-inference + provider_type: remote::vllm + config: + url: ${env.VLLM_URL:=http://localhost:8000/v1} + max_tokens: ${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: vllm-safety + provider_type: remote::vllm + config: + url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1} + max_tokens: 
${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: mock-vllm-inference + provider_type: remote::vllm + config: + url: http://openai-mock-service:${env.MOCK_INFERENCE_PORT} + max_tokens: 4096 + api_token: fake + tls_verify: false + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} + vector_io: + - provider_id: ${env.ENABLE_CHROMADB:+chromadb} + provider_type: remote::chromadb + config: + url: ${env.CHROMADB_URL:=} + kvstore: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + responses_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" + sinks: ${env.TELEMETRY_SINKS:=console} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:+} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:+} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} +metadata_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + table_name: llamastack_kvstore +inference_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} +models: +- metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: sentence-transformers + model_type: embedding +- model_id: ${env.INFERENCE_MODEL} + provider_id: vllm-inference + model_type: llm +- model_id: ${env.SAFETY_MODEL} + provider_id: vllm-safety + model_type: llm +- model_id: ${env.MOCK_INFERENCE_MODEL} + provider_id: mock-vllm-inference + model_type: llm +shields: +- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B} +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +server: + port: 8323 diff --git a/docs/source/distributions/k8s/stack-k8s.yaml.template b/docs/source/distributions/k8s/stack-k8s.yaml.template index ad5d2c716..dfc049f4f 100644 --- a/docs/source/distributions/k8s/stack-k8s.yaml.template +++ 
b/docs/source/distributions/k8s/stack-k8s.yaml.template @@ -40,19 +40,19 @@ spec: value: "3072" - name: VLLM_SAFETY_URL value: http://vllm-server-safety.default.svc.cluster.local:8001/v1 + - name: VLLM_TLS_VERIFY + value: "false" - name: POSTGRES_HOST value: postgres-server.default.svc.cluster.local - name: POSTGRES_PORT value: "5432" - - name: VLLM_TLS_VERIFY - value: "false" - name: INFERENCE_MODEL value: "${INFERENCE_MODEL}" - name: SAFETY_MODEL value: "${SAFETY_MODEL}" - name: TAVILY_SEARCH_API_KEY value: "${TAVILY_SEARCH_API_KEY}" - command: ["python", "-m", "llama_stack.core.server.server", "--config", "/etc/config/stack_run_config.yaml", "--port", "8321"] + command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8321"] ports: - containerPort: 8321 volumeMounts: diff --git a/docs/source/providers/external/external-providers-guide.md b/docs/source/providers/external/external-providers-guide.md index 2479d406f..e2d4ebea9 100644 --- a/docs/source/providers/external/external-providers-guide.md +++ b/docs/source/providers/external/external-providers-guide.md @@ -226,7 +226,7 @@ uv init name = "llama-stack-provider-ollama" version = "0.1.0" description = "Ollama provider for Llama Stack" -requires-python = ">=3.10" +requires-python = ">=3.12" dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"] ``` diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md index 1c7bc86b9..38781e5eb 100644 --- a/docs/source/providers/inference/index.md +++ b/docs/source/providers/inference/index.md @@ -29,6 +29,7 @@ remote_runpod remote_sambanova remote_tgi remote_together +remote_vertexai remote_vllm remote_watsonx ``` diff --git a/docs/source/providers/inference/remote_vertexai.md b/docs/source/providers/inference/remote_vertexai.md new file mode 100644 index 000000000..962bbd76f --- /dev/null +++ b/docs/source/providers/inference/remote_vertexai.md @@ -0,0 +1,40 @@ +# remote::vertexai + +## Description + +Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages: + +β€’ Enterprise-grade security: Uses Google Cloud's security controls and IAM +β€’ Better integration: Seamless integration with other Google Cloud services +β€’ Advanced features: Access to additional Vertex AI features like model tuning and monitoring +β€’ Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys + +Configuration: +- Set VERTEX_AI_PROJECT environment variable (required) +- Set VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1) +- Use Google Cloud Application Default Credentials or service account key + +Authentication Setup: +Option 1 (Recommended): gcloud auth application-default login +Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to service account key path + +Available Models: +- vertex_ai/gemini-2.0-flash +- vertex_ai/gemini-2.5-flash +- vertex_ai/gemini-2.5-pro + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `project` | `` | No | | Google Cloud project ID for Vertex AI | +| `location` | `` | No | us-central1 | Google Cloud location for Vertex AI | + +## Sample Configuration + +```yaml +project: ${env.VERTEX_AI_PROJECT:=} +location: ${env.VERTEX_AI_LOCATION:=us-central1} + +``` + diff --git a/docs/source/providers/vector_io/inline_faiss.md b/docs/source/providers/vector_io/inline_faiss.md 
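As a side note on the Vertex AI provider documented above: once a stack with this provider is running, it can be exercised through the stack's OpenAI-compatible endpoint, the same `/v1/openai/v1` base path the benchmark locustfile targets. The sketch below is illustrative only and not part of this patch; the local port (8321) and the model id are assumptions taken from the surrounding templates and docs, so adjust them to your deployment.

```python
# Hypothetical usage sketch: query a running Llama Stack server that has the
# remote::vertexai provider enabled. Assumes a local server on port 8321 and
# that VERTEX_AI_PROJECT / Application Default Credentials were configured
# before the server was started.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",  # Llama Stack's OpenAI-compatible base path
    api_key="none",  # a local, unauthenticated stack ignores the key
)

response = client.chat.completions.create(
    model="vertex_ai/gemini-2.0-flash",  # one of the model ids listed above
    messages=[{"role": "user", "content": "Say hello from Vertex AI."}],
)
print(response.choices[0].message.content)
```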
index bcff66f3f..cfa18a839 100644 --- a/docs/source/providers/vector_io/inline_faiss.md +++ b/docs/source/providers/vector_io/inline_faiss.md @@ -12,6 +12,18 @@ That means you'll get fast and efficient vector retrieval. - Lightweight and easy to use - Fully integrated with Llama Stack - GPU support +- **Vector search** - FAISS supports pure vector similarity search using embeddings + +## Search Modes + +**Supported:** +- **Vector Search** (`mode="vector"`): Performs vector similarity search using embeddings + +**Not Supported:** +- **Keyword Search** (`mode="keyword"`): Not supported by FAISS +- **Hybrid Search** (`mode="hybrid"`): Not supported by FAISS + +> **Note**: FAISS is designed as a pure vector similarity search library. See the [FAISS GitHub repository](https://github.com/facebookresearch/faiss) for more details about FAISS's core functionality. ## Usage diff --git a/docs/source/providers/vector_io/inline_meta-reference.md b/docs/source/providers/vector_io/inline_meta-reference.md index 0aac445bd..6f269c441 100644 --- a/docs/source/providers/vector_io/inline_meta-reference.md +++ b/docs/source/providers/vector_io/inline_meta-reference.md @@ -21,5 +21,7 @@ kvstore: ## Deprecation Notice -⚠️ **Warning**: Please use the `inline::faiss` provider instead. +```{warning} +Please use the `inline::faiss` provider instead. +``` diff --git a/docs/source/providers/vector_io/inline_sqlite_vec.md b/docs/source/providers/vector_io/inline_sqlite_vec.md index 7ad8eb252..9e5654a50 100644 --- a/docs/source/providers/vector_io/inline_sqlite_vec.md +++ b/docs/source/providers/vector_io/inline_sqlite_vec.md @@ -25,5 +25,7 @@ kvstore: ## Deprecation Notice -⚠️ **Warning**: Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead. +```{warning} +Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead. +``` diff --git a/docs/source/providers/vector_io/remote_milvus.md b/docs/source/providers/vector_io/remote_milvus.md index 3646f4acc..075423d04 100644 --- a/docs/source/providers/vector_io/remote_milvus.md +++ b/docs/source/providers/vector_io/remote_milvus.md @@ -11,6 +11,7 @@ That means you're not limited to storing vectors in memory or in a separate serv - Easy to use - Fully integrated with Llama Stack +- Supports all search modes: vector, keyword, and hybrid search (both inline and remote configurations) ## Usage @@ -101,6 +102,92 @@ vector_io: - **`client_pem_path`**: Path to the **client certificate** file (required for mTLS). - **`client_key_path`**: Path to the **client private key** file (required for mTLS). +## Search Modes + +Milvus supports three different search modes for both inline and remote configurations: + +### Vector Search +Vector search uses semantic similarity to find the most relevant chunks based on embedding vectors. This is the default search mode and works well for finding conceptually similar content. + +```python +# Vector search example +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="What is machine learning?", + search_mode="vector", + max_num_results=5, +) +``` + +### Keyword Search +Keyword search uses traditional text-based matching to find chunks containing specific terms or phrases. This is useful when you need exact term matches. 
+ +```python +# Keyword search example +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="Python programming language", + search_mode="keyword", + max_num_results=5, +) +``` + +### Hybrid Search +Hybrid search combines both vector and keyword search methods to provide more comprehensive results. It leverages the strengths of both semantic similarity and exact term matching. + +#### Basic Hybrid Search +```python +# Basic hybrid search example (uses RRF ranker with default impact_factor=60.0) +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, +) +``` + +**Note**: The default `impact_factor` value of 60.0 was empirically determined to be optimal in the original RRF research paper: ["Reciprocal Rank Fusion outperforms Condorcet and individual Rank Learning Methods"](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) (Cormack et al., 2009). + +#### Hybrid Search with RRF (Reciprocal Rank Fusion) Ranker +RRF combines rankings from vector and keyword search by using reciprocal ranks. The impact factor controls how much weight is given to higher-ranked results. + +```python +# Hybrid search with custom RRF parameters +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, + ranking_options={ + "ranker": { + "type": "rrf", + "impact_factor": 100.0, # Higher values give more weight to top-ranked results + } + }, +) +``` + +#### Hybrid Search with Weighted Ranker +Weighted ranker linearly combines normalized scores from vector and keyword search. The alpha parameter controls the balance between the two search methods. + +```python +# Hybrid search with weighted ranker +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, + ranking_options={ + "ranker": { + "type": "weighted", + "alpha": 0.7, # 70% vector search, 30% keyword search + } + }, +) +``` + +For detailed documentation on RRF and Weighted rankers, please refer to the [Milvus Reranking Guide](https://milvus.io/docs/reranking.md). + ## Documentation See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general. @@ -117,7 +204,10 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi | `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend | | `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. | -> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider. +```{note} + This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider. 
+ ``` + ## Sample Configuration diff --git a/docs/source/references/llama_cli_reference/download_models.md b/docs/source/references/llama_cli_reference/download_models.md index e32099023..a9af65349 100644 --- a/docs/source/references/llama_cli_reference/download_models.md +++ b/docs/source/references/llama_cli_reference/download_models.md @@ -128,7 +128,9 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern **Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). -> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. +```{tip} +Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. +``` ## List the downloaded models diff --git a/docs/source/references/llama_cli_reference/index.md b/docs/source/references/llama_cli_reference/index.md index 4ef76fe7d..09a8b7177 100644 --- a/docs/source/references/llama_cli_reference/index.md +++ b/docs/source/references/llama_cli_reference/index.md @@ -152,7 +152,9 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern **Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). -> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. +```{tip} +Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. +``` ## List the downloaded models diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index e816da766..7dd3e9289 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -706,6 +706,7 @@ class Agents(Protocol): temperature: float | None = None, text: OpenAIResponseText | None = None, tools: list[OpenAIResponseInputTool] | None = None, + include: list[str] | None = None, max_infer_iters: int | None = 10, # this is an extension to the OpenAI API ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]: """Create a new OpenAI response. @@ -713,6 +714,7 @@ class Agents(Protocol): :param input: Input message(s) to create the response. :param model: The underlying LLM used for completions. :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. 
This can be used to easily fork-off new responses from existing responses. + :param include: (Optional) Additional fields to include in the response. :returns: An OpenAIResponseObject. """ ... diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index 10cadf38f..591992479 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -170,6 +170,23 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel): type: Literal["web_search_call"] = "web_search_call" +class OpenAIResponseOutputMessageFileSearchToolCallResults(BaseModel): + """Search results returned by the file search operation. + + :param attributes: (Optional) Key-value attributes associated with the file + :param file_id: Unique identifier of the file containing the result + :param filename: Name of the file containing the result + :param score: Relevance score for this search result (between 0 and 1) + :param text: Text content of the search result + """ + + attributes: dict[str, Any] + file_id: str + filename: str + score: float + text: str + + @json_schema_type class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel): """File search tool call output message for OpenAI responses. @@ -185,7 +202,7 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel): queries: list[str] status: str type: Literal["file_search_call"] = "file_search_call" - results: list[dict[str, Any]] | None = None + results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None @json_schema_type @@ -606,6 +623,62 @@ class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel): type: Literal["response.mcp_call.completed"] = "response.mcp_call.completed" +@json_schema_type +class OpenAIResponseContentPartOutputText(BaseModel): + type: Literal["output_text"] = "output_text" + text: str + # TODO: add annotations, logprobs, etc. + + +@json_schema_type +class OpenAIResponseContentPartRefusal(BaseModel): + type: Literal["refusal"] = "refusal" + refusal: str + + +OpenAIResponseContentPart = Annotated[ + OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal, + Field(discriminator="type"), +] +register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart") + + +@json_schema_type +class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel): + """Streaming event for when a new content part is added to a response item. + + :param response_id: Unique identifier of the response containing this content + :param item_id: Unique identifier of the output item containing this content part + :param part: The content part that was added + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.content_part.added" + """ + + response_id: str + item_id: str + part: OpenAIResponseContentPart + sequence_number: int + type: Literal["response.content_part.added"] = "response.content_part.added" + + +@json_schema_type +class OpenAIResponseObjectStreamResponseContentPartDone(BaseModel): + """Streaming event for when a content part is completed. 
+ + :param response_id: Unique identifier of the response containing this content + :param item_id: Unique identifier of the output item containing this content part + :param part: The completed content part + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.content_part.done" + """ + + response_id: str + item_id: str + part: OpenAIResponseContentPart + sequence_number: int + type: Literal["response.content_part.done"] = "response.content_part.done" + + OpenAIResponseObjectStream = Annotated[ OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseOutputItemAdded @@ -625,6 +698,8 @@ OpenAIResponseObjectStream = Annotated[ | OpenAIResponseObjectStreamResponseMcpCallInProgress | OpenAIResponseObjectStreamResponseMcpCallFailed | OpenAIResponseObjectStreamResponseMcpCallCompleted + | OpenAIResponseObjectStreamResponseContentPartAdded + | OpenAIResponseObjectStreamResponseContentPartDone | OpenAIResponseObjectStreamResponseCompleted, Field(discriminator="type"), ] diff --git a/llama_stack/apis/common/errors.py b/llama_stack/apis/common/errors.py index 95d6ac18e..6e0fa0b3c 100644 --- a/llama_stack/apis/common/errors.py +++ b/llama_stack/apis/common/errors.py @@ -62,3 +62,13 @@ class SessionNotFoundError(ValueError): def __init__(self, session_name: str) -> None: message = f"Session '{session_name}' not found or access denied." super().__init__(message) + + +class ModelTypeError(TypeError): + """raised when a model is present but not the correct type""" + + def __init__(self, model_name: str, model_type: str, expected_model_type: str) -> None: + message = ( + f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'" + ) + super().__init__(message) diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 3f374460b..25ee03ec1 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum, StrEnum +from enum import Enum from typing import Any, Protocol, runtime_checkable from pydantic import BaseModel, Field @@ -15,27 +15,6 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod -# OpenAI Categories to return in the response -class OpenAICategories(StrEnum): - """ - Required set of categories in moderations api response - """ - - VIOLENCE = "violence" - VIOLENCE_GRAPHIC = "violence/graphic" - HARRASMENT = "harassment" - HARRASMENT_THREATENING = "harassment/threatening" - HATE = "hate" - HATE_THREATENING = "hate/threatening" - ILLICIT = "illicit" - ILLICIT_VIOLENT = "illicit/violent" - SEXUAL = "sexual" - SEXUAL_MINORS = "sexual/minors" - SELF_HARM = "self-harm" - SELF_HARM_INTENT = "self-harm/intent" - SELF_HARM_INSTRUCTIONS = "self-harm/instructions" - - @json_schema_type class ModerationObjectResults(BaseModel): """A moderation object. @@ -43,20 +22,6 @@ class ModerationObjectResults(BaseModel): :param categories: A list of the categories, and whether they are flagged or not. :param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to. :param category_scores: A list of the categories along with their scores as predicted by model. 
- Required set of categories that need to be in response - - violence - - violence/graphic - - harassment - - harassment/threatening - - hate - - hate/threatening - - illicit - - illicit/violent - - sexual - - sexual/minors - - self-harm - - self-harm/intent - - self-harm/instructions """ flagged: bool diff --git a/llama_stack/core/build.py b/llama_stack/core/build.py index b3e35ecef..4b20588fd 100644 --- a/llama_stack/core/build.py +++ b/llama_stack/core/build.py @@ -91,7 +91,7 @@ def get_provider_dependencies( def print_pip_install_help(config: BuildConfig): - normal_deps, special_deps = get_provider_dependencies(config) + normal_deps, special_deps, _ = get_provider_dependencies(config) cprint( f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}", diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index 5fbbf1aff..a93fe509e 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -380,8 +380,17 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): json_content = json.dumps(convert_pydantic_to_json_value(result)) filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)} + + status_code = httpx.codes.OK + + if options.method.upper() == "DELETE" and result is None: + status_code = httpx.codes.NO_CONTENT + + if status_code == httpx.codes.NO_CONTENT: + json_content = "" + mock_response = httpx.Response( - status_code=httpx.codes.OK, + status_code=status_code, content=json_content.encode("utf-8"), headers={ "Content-Type": "application/json", diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 79ab7c34f..6a3f07247 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -18,7 +18,7 @@ from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, ) -from llama_stack.apis.common.errors import ModelNotFoundError +from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError from llama_stack.apis.inference import ( BatchChatCompletionResponse, BatchCompletionResponse, @@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin from llama_stack.providers.utils.inference.inference_store import InferenceStore from llama_stack.providers.utils.telemetry.tracing import get_current_span -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="inference") class InferenceRouter(Inference): @@ -177,6 +177,15 @@ class InferenceRouter(Inference): encoded = self.formatter.encode_content(messages) return len(encoded.tokens) if encoded and encoded.tokens else 0 + async def _get_model(self, model_id: str, expected_model_type: str) -> Model: + """takes a model id and gets model after ensuring that it is accessible and of the correct type""" + model = await self.routing_table.get_model(model_id) + if model is None: + raise ModelNotFoundError(model_id) + if model.model_type != expected_model_type: + raise ModelTypeError(model_id, model.model_type, expected_model_type) + return model + async def chat_completion( self, model_id: str, @@ -195,11 +204,7 @@ class InferenceRouter(Inference): ) if sampling_params is None: sampling_params = SamplingParams() - model = await self.routing_table.get_model(model_id) - if model is None: - raise ModelNotFoundError(model_id) - if model.model_type == ModelType.embedding: - raise ValueError(f"Model 
'{model_id}' is an embedding model and does not support chat completions") + model = await self._get_model(model_id, ModelType.llm) if tool_config: if tool_choice and tool_choice != tool_config.tool_choice: raise ValueError("tool_choice and tool_config.tool_choice must match") @@ -301,11 +306,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}", ) - model = await self.routing_table.get_model(model_id) - if model is None: - raise ModelNotFoundError(model_id) - if model.model_type == ModelType.embedding: - raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") + model = await self._get_model(model_id, ModelType.llm) provider = await self.routing_table.get_provider_impl(model_id) params = dict( model_id=model_id, @@ -355,11 +356,7 @@ class InferenceRouter(Inference): task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: logger.debug(f"InferenceRouter.embeddings: {model_id}") - model = await self.routing_table.get_model(model_id) - if model is None: - raise ModelNotFoundError(model_id) - if model.model_type == ModelType.llm: - raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings") + await self._get_model(model_id, ModelType.embedding) provider = await self.routing_table.get_provider_impl(model_id) return await provider.embeddings( model_id=model_id, @@ -395,12 +392,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}", ) - model_obj = await self.routing_table.get_model(model) - if model_obj is None: - raise ModelNotFoundError(model) - if model_obj.model_type == ModelType.embedding: - raise ValueError(f"Model '{model}' is an embedding model and does not support completions") - + model_obj = await self._get_model(model, ModelType.llm) params = dict( model=model_obj.identifier, prompt=prompt, @@ -476,11 +468,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}", ) - model_obj = await self.routing_table.get_model(model) - if model_obj is None: - raise ModelNotFoundError(model) - if model_obj.model_type == ModelType.embedding: - raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions") + model_obj = await self._get_model(model, ModelType.llm) # Use the OpenAI client for a bit of extra input validation without # exposing the OpenAI client itself as part of our API surface @@ -567,12 +555,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}", ) - model_obj = await self.routing_table.get_model(model) - if model_obj is None: - raise ModelNotFoundError(model) - if model_obj.model_type != ModelType.embedding: - raise ValueError(f"Model '{model}' is not an embedding model") - + model_obj = await self._get_model(model, ModelType.embedding) params = dict( model=model_obj.identifier, input=input, @@ -871,4 +854,5 @@ class InferenceRouter(Inference): model=model.identifier, object="chat.completion", ) + logger.debug(f"InferenceRouter.completion_response: {final_response}") await self.store.store_chat_completion(final_response, messages) diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py index 9bf2b1bac..c76673d2a 100644 --- a/llama_stack/core/routers/safety.py +++ 
b/llama_stack/core/routers/safety.py @@ -10,7 +10,7 @@ from llama_stack.apis.inference import ( Message, ) from llama_stack.apis.safety import RunShieldResponse, Safety -from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories +from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable @@ -82,20 +82,5 @@ class SafetyRouter(Safety): input=input, model=model, ) - self._validate_required_categories_exist(response) return response - - def _validate_required_categories_exist(self, response: ModerationObject) -> None: - """Validate the ProviderImpl response contains the required Open AI moderations categories.""" - required_categories = list(map(str, OpenAICategories)) - - categories = response.results[0].categories - category_applied_input_types = response.results[0].category_applied_input_types - category_scores = response.results[0].category_scores - - for i in [categories, category_applied_input_types, category_scores]: - if not set(required_categories).issubset(set(i.keys())): - raise ValueError( - f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}" - ) diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index c76619271..34c431e00 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -63,6 +63,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def get_provider_impl(self, model_id: str) -> Any: model = await lookup_model(self, model_id) + if model.provider_id not in self.impls_by_provider_id: + raise ValueError(f"Provider {model.provider_id} not found in the routing table") return self.impls_by_provider_id[model.provider_id] async def register_model( diff --git a/llama_stack/core/routing_tables/toolgroups.py b/llama_stack/core/routing_tables/toolgroups.py index e172af991..6910b3906 100644 --- a/llama_stack/core/routing_tables/toolgroups.py +++ b/llama_stack/core/routing_tables/toolgroups.py @@ -124,10 +124,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): return toolgroup async def unregister_toolgroup(self, toolgroup_id: str) -> None: - tool_group = await self.get_tool_group(toolgroup_id) - if tool_group is None: - raise ToolGroupNotFoundError(toolgroup_id) - await self.unregister_object(tool_group) + await self.unregister_object(await self.get_tool_group(toolgroup_id)) async def shutdown(self) -> None: pass diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py index c81a27a3b..e8dc46997 100644 --- a/llama_stack/core/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -8,7 +8,7 @@ from typing import Any from pydantic import TypeAdapter -from llama_stack.apis.common.errors import ModelNotFoundError, VectorStoreNotFoundError +from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError from llama_stack.apis.models import ModelType from llama_stack.apis.resource import ResourceType from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs @@ -66,7 +66,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): if model is None: raise ModelNotFoundError(embedding_model) if model.model_type != ModelType.embedding: - raise ValueError(f"Model {embedding_model} is not an embedding model") + raise 
ModelTypeError(embedding_model, model.model_type, ModelType.embedding) if "embedding_dimension" not in model.metadata: raise ValueError(f"Model {embedding_model} does not have an embedding dimension") vector_db_data = { diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index fe5cc68d7..e9d70fc8d 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -21,10 +21,11 @@ from importlib.metadata import version as parse_version from pathlib import Path from typing import Annotated, Any, get_origin +import httpx import rich.pretty import yaml from aiohttp import hdrs -from fastapi import Body, FastAPI, HTTPException, Request +from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse @@ -115,7 +116,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro if isinstance(exc, RequestValidationError): return HTTPException( - status_code=400, + status_code=httpx.codes.BAD_REQUEST, detail={ "errors": [ { @@ -128,20 +129,20 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro }, ) elif isinstance(exc, ValueError): - return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}") + return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}") elif isinstance(exc, BadRequestError): - return HTTPException(status_code=400, detail=str(exc)) + return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc)) elif isinstance(exc, PermissionError | AccessDeniedError): - return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}") + return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}") elif isinstance(exc, asyncio.TimeoutError | TimeoutError): - return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}") + return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}") elif isinstance(exc, NotImplementedError): - return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}") + return HTTPException(status_code=httpx.codes.NOT_IMPLEMENTED, detail=f"Not implemented: {str(exc)}") elif isinstance(exc, AuthenticationRequiredError): - return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}") + return HTTPException(status_code=httpx.codes.UNAUTHORIZED, detail=f"Authentication required: {str(exc)}") else: return HTTPException( - status_code=500, + status_code=httpx.codes.INTERNAL_SERVER_ERROR, detail="Internal server error: An unexpected error occurred.", ) @@ -180,7 +181,6 @@ async def sse_generator(event_gen_coroutine): event_gen = await event_gen_coroutine async for item in event_gen: yield create_sse_event(item) - await asyncio.sleep(0.01) except asyncio.CancelledError: logger.info("Generator cancelled") if event_gen: @@ -236,6 +236,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: result = await maybe_await(value) if isinstance(result, PaginatedResponse) and result.url is None: result.url = route + + if method.upper() == "DELETE" and result is None: + return Response(status_code=httpx.codes.NO_CONTENT) + return result except Exception as e: if logger.isEnabledFor(logging.DEBUG): @@ -352,7 +356,7 @@ class ClientVersionMiddleware: await send( { "type": "http.response.start", - 
"status": 426, + "status": httpx.codes.UPGRADE_REQUIRED, "headers": [[b"content-type", b"application/json"]], } ) diff --git a/llama_stack/distributions/ci-tests/build.yaml b/llama_stack/distributions/ci-tests/build.yaml index 2f9ae8682..e6e699b62 100644 --- a/llama_stack/distributions/ci-tests/build.yaml +++ b/llama_stack/distributions/ci-tests/build.yaml @@ -14,6 +14,7 @@ distribution_spec: - provider_type: remote::openai - provider_type: remote::anthropic - provider_type: remote::gemini + - provider_type: remote::vertexai - provider_type: remote::groq - provider_type: remote::sambanova - provider_type: inline::sentence-transformers diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml index 188c66275..05e1b4576 100644 --- a/llama_stack/distributions/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -65,6 +65,11 @@ providers: provider_type: remote::gemini config: api_key: ${env.GEMINI_API_KEY:=} + - provider_id: ${env.VERTEX_AI_PROJECT:+vertexai} + provider_type: remote::vertexai + config: + project: ${env.VERTEX_AI_PROJECT:=} + location: ${env.VERTEX_AI_LOCATION:=us-central1} - provider_id: groq provider_type: remote::groq config: diff --git a/llama_stack/distributions/dell/dell.py b/llama_stack/distributions/dell/dell.py index b561ea00e..e3bf0ee03 100644 --- a/llama_stack/distributions/dell/dell.py +++ b/llama_stack/distributions/dell/dell.py @@ -16,6 +16,7 @@ from llama_stack.distributions.template import DistributionTemplate, RunConfigSe from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) +from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig def get_distribution_template() -> DistributionTemplate: @@ -71,9 +72,10 @@ def get_distribution_template() -> DistributionTemplate: chromadb_provider = Provider( provider_id="chromadb", provider_type="remote::chromadb", - config={ - "url": "${env.CHROMA_URL}", - }, + config=ChromaVectorIOConfig.sample_run_config( + f"~/.llama/distributions/{name}/", + url="${env.CHROMADB_URL:=}", + ), ) inference_model = ModelInput( diff --git a/llama_stack/distributions/dell/run-with-safety.yaml b/llama_stack/distributions/dell/run-with-safety.yaml index ecc6729eb..d89c92aa1 100644 --- a/llama_stack/distributions/dell/run-with-safety.yaml +++ b/llama_stack/distributions/dell/run-with-safety.yaml @@ -26,7 +26,10 @@ providers: - provider_id: chromadb provider_type: remote::chromadb config: - url: ${env.CHROMA_URL} + url: ${env.CHROMADB_URL:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/distributions/dell/run.yaml b/llama_stack/distributions/dell/run.yaml index fc2553526..7397410ba 100644 --- a/llama_stack/distributions/dell/run.yaml +++ b/llama_stack/distributions/dell/run.yaml @@ -22,7 +22,10 @@ providers: - provider_id: chromadb provider_type: remote::chromadb config: - url: ${env.CHROMA_URL} + url: ${env.CHROMADB_URL:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/distributions/starter/build.yaml b/llama_stack/distributions/starter/build.yaml index f95a03a9e..1a4f81d49 100644 --- a/llama_stack/distributions/starter/build.yaml +++ b/llama_stack/distributions/starter/build.yaml 
@@ -14,6 +14,7 @@ distribution_spec: - provider_type: remote::openai - provider_type: remote::anthropic - provider_type: remote::gemini + - provider_type: remote::vertexai - provider_type: remote::groq - provider_type: remote::sambanova - provider_type: inline::sentence-transformers diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml index 8bd737686..46bd12956 100644 --- a/llama_stack/distributions/starter/run.yaml +++ b/llama_stack/distributions/starter/run.yaml @@ -65,6 +65,11 @@ providers: provider_type: remote::gemini config: api_key: ${env.GEMINI_API_KEY:=} + - provider_id: ${env.VERTEX_AI_PROJECT:+vertexai} + provider_type: remote::vertexai + config: + project: ${env.VERTEX_AI_PROJECT:=} + location: ${env.VERTEX_AI_LOCATION:=us-central1} - provider_id: groq provider_type: remote::groq config: diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py index a970f2d1c..0270b68ad 100644 --- a/llama_stack/distributions/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -56,6 +56,7 @@ ENABLED_INFERENCE_PROVIDERS = [ "fireworks", "together", "gemini", + "vertexai", "groq", "sambanova", "anthropic", @@ -71,6 +72,7 @@ INFERENCE_PROVIDER_IDS = { "tgi": "${env.TGI_URL:+tgi}", "cerebras": "${env.CEREBRAS_API_KEY:+cerebras}", "nvidia": "${env.NVIDIA_API_KEY:+nvidia}", + "vertexai": "${env.VERTEX_AI_PROJECT:+vertexai}", } @@ -246,6 +248,14 @@ def get_distribution_template() -> DistributionTemplate: "", "Gemini API Key", ), + "VERTEX_AI_PROJECT": ( + "", + "Google Cloud Project ID for Vertex AI", + ), + "VERTEX_AI_LOCATION": ( + "us-central1", + "Google Cloud Location for Vertex AI", + ), "SAMBANOVA_API_KEY": ( "", "SambaNova API Key", diff --git a/llama_stack/log.py b/llama_stack/log.py index ab53e08c0..7507aface 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -32,6 +32,7 @@ CATEGORIES = [ "tools", "client", "telemetry", + "openai_responses", ] # Initialize category levels with default level @@ -99,7 +100,8 @@ def parse_environment_config(env_config: str) -> dict[str, int]: Dict[str, int]: A dictionary mapping categories to their log levels. 
""" category_levels = {} - for pair in env_config.split(";"): + delimiter = "," + for pair in env_config.split(delimiter): if not pair.strip(): continue diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py index 0a973cf0c..1f88a1699 100644 --- a/llama_stack/models/llama/llama3/chat_format.py +++ b/llama_stack/models/llama/llama3/chat_format.py @@ -236,6 +236,7 @@ class ChatFormat: arguments_json=json.dumps(tool_arguments), ) ) + content = "" return RawMessage( role="assistant", diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 5f7c90879..e9f89f8d2 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -68,6 +68,11 @@ from llama_stack.models.llama.datatypes import ( BuiltinTool, ToolCall, ) +from llama_stack.providers.utils.inference.openai_compat import ( + convert_message_to_openai_dict, + convert_openai_chat_completion_stream, + convert_tooldef_to_openai_tool, +) from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.telemetry import tracing @@ -510,16 +515,60 @@ class ChatAgent(ShieldRunnerMixin): async with tracing.span("inference") as span: if self.agent_config.name: span.set_attribute("agent_name", self.agent_config.name) - async for chunk in await self.inference_api.chat_completion( - self.agent_config.model, - input_messages, - tools=self.tool_defs, - tool_prompt_format=self.agent_config.tool_config.tool_prompt_format, - response_format=self.agent_config.response_format, + # Convert messages to OpenAI format + openai_messages = [] + for message in input_messages: + openai_message = await convert_message_to_openai_dict(message) + openai_messages.append(openai_message) + + # Convert tool definitions to OpenAI format + openai_tools = None + if self.tool_defs: + openai_tools = [] + for tool_def in self.tool_defs: + openai_tool = convert_tooldef_to_openai_tool(tool_def) + openai_tools.append(openai_tool) + + # Extract tool_choice from tool_config for OpenAI compatibility + # Note: tool_choice can only be provided when tools are also provided + tool_choice = None + if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice: + tool_choice = ( + self.agent_config.tool_config.tool_choice.value + if hasattr(self.agent_config.tool_config.tool_choice, "value") + else str(self.agent_config.tool_config.tool_choice) + ) + + # Convert sampling params to OpenAI format (temperature, top_p, max_tokens) + temperature = None + top_p = None + max_tokens = None + if sampling_params: + if hasattr(sampling_params.strategy, "temperature"): + temperature = sampling_params.strategy.temperature + if hasattr(sampling_params.strategy, "top_p"): + top_p = sampling_params.strategy.top_p + if sampling_params.max_tokens: + max_tokens = sampling_params.max_tokens + + # Use OpenAI chat completion + openai_stream = await self.inference_api.openai_chat_completion( + model=self.agent_config.model, + messages=openai_messages, + tools=openai_tools if openai_tools else None, + tool_choice=tool_choice, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, stream=True, - sampling_params=sampling_params, - tool_config=self.agent_config.tool_config, - ): + ) + + # Convert OpenAI stream back to Llama Stack format + response_stream = convert_openai_chat_completion_stream( + 
openai_stream, enable_incremental_tool_calls=True + ) + + async for chunk in response_stream: event = chunk.event if event.event_type == ChatCompletionResponseEventType.start: continue diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 15695ec48..30196c429 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -48,8 +48,8 @@ from llama_stack.providers.utils.responses.responses_store import ResponsesStore from .agent_instance import ChatAgent from .config import MetaReferenceAgentsImplConfig -from .openai_responses import OpenAIResponsesImpl from .persistence import AgentInfo +from .responses.openai_responses import OpenAIResponsesImpl logger = logging.getLogger() @@ -327,10 +327,21 @@ class MetaReferenceAgentsImpl(Agents): temperature: float | None = None, text: OpenAIResponseText | None = None, tools: list[OpenAIResponseInputTool] | None = None, + include: list[str] | None = None, max_infer_iters: int | None = 10, ) -> OpenAIResponseObject: return await self.openai_responses_impl.create_openai_response( - input, model, instructions, previous_response_id, store, stream, temperature, text, tools, max_infer_iters + input, + model, + instructions, + previous_response_id, + store, + stream, + temperature, + text, + tools, + include, + max_infer_iters, ) async def list_openai_responses( diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py deleted file mode 100644 index 7eb2b3897..000000000 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ /dev/null @@ -1,880 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import asyncio -import json -import time -import uuid -from collections.abc import AsyncIterator -from typing import Any - -from openai.types.chat import ChatCompletionToolParam -from pydantic import BaseModel - -from llama_stack.apis.agents import Order -from llama_stack.apis.agents.openai_responses import ( - AllowedToolsFilter, - ListOpenAIResponseInputItem, - ListOpenAIResponseObject, - OpenAIDeleteResponseObject, - OpenAIResponseInput, - OpenAIResponseInputFunctionToolCallOutput, - OpenAIResponseInputMessageContent, - OpenAIResponseInputMessageContentImage, - OpenAIResponseInputMessageContentText, - OpenAIResponseInputTool, - OpenAIResponseInputToolFileSearch, - OpenAIResponseInputToolMCP, - OpenAIResponseMessage, - OpenAIResponseObject, - OpenAIResponseObjectStream, - OpenAIResponseObjectStreamResponseCompleted, - OpenAIResponseObjectStreamResponseCreated, - OpenAIResponseObjectStreamResponseOutputTextDelta, - OpenAIResponseOutput, - OpenAIResponseOutputMessageContent, - OpenAIResponseOutputMessageContentOutputText, - OpenAIResponseOutputMessageFileSearchToolCall, - OpenAIResponseOutputMessageFunctionToolCall, - OpenAIResponseOutputMessageMCPListTools, - OpenAIResponseOutputMessageWebSearchToolCall, - OpenAIResponseText, - OpenAIResponseTextFormat, - WebSearchToolTypes, -) -from llama_stack.apis.common.content_types import TextContentItem -from llama_stack.apis.inference import ( - Inference, - OpenAIAssistantMessageParam, - OpenAIChatCompletion, - OpenAIChatCompletionContentPartImageParam, - OpenAIChatCompletionContentPartParam, - OpenAIChatCompletionContentPartTextParam, - OpenAIChatCompletionToolCall, - OpenAIChatCompletionToolCallFunction, - OpenAIChoice, - OpenAIDeveloperMessageParam, - OpenAIImageURL, - OpenAIJSONSchema, - OpenAIMessageParam, - OpenAIResponseFormatJSONObject, - OpenAIResponseFormatJSONSchema, - OpenAIResponseFormatParam, - OpenAIResponseFormatText, - OpenAISystemMessageParam, - OpenAIToolMessageParam, - OpenAIUserMessageParam, -) -from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime -from llama_stack.apis.vector_io import VectorIO -from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition -from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool -from llama_stack.providers.utils.responses.responses_store import ResponsesStore - -logger = get_logger(name=__name__, category="openai_responses") - -OPENAI_RESPONSES_PREFIX = "openai_responses:" - - -async def _convert_response_content_to_chat_content( - content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent], -) -> str | list[OpenAIChatCompletionContentPartParam]: - """ - Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts. - - The content schemas of each API look similar, but are not exactly the same. 
- """ - if isinstance(content, str): - return content - - converted_parts = [] - for content_part in content: - if isinstance(content_part, OpenAIResponseInputMessageContentText): - converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) - elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText): - converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) - elif isinstance(content_part, OpenAIResponseInputMessageContentImage): - if content_part.image_url: - image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail) - converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url)) - elif isinstance(content_part, str): - converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part)) - else: - raise ValueError( - f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context" - ) - return converted_parts - - -async def _convert_response_input_to_chat_messages( - input: str | list[OpenAIResponseInput], -) -> list[OpenAIMessageParam]: - """ - Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages. - """ - messages: list[OpenAIMessageParam] = [] - if isinstance(input, list): - for input_item in input: - if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): - messages.append( - OpenAIToolMessageParam( - content=input_item.output, - tool_call_id=input_item.call_id, - ) - ) - elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall): - tool_call = OpenAIChatCompletionToolCall( - index=0, - id=input_item.call_id, - function=OpenAIChatCompletionToolCallFunction( - name=input_item.name, - arguments=input_item.arguments, - ), - ) - messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) - else: - content = await _convert_response_content_to_chat_content(input_item.content) - message_type = await _get_message_type_by_role(input_item.role) - if message_type is None: - raise ValueError( - f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" - ) - messages.append(message_type(content=content)) - else: - messages.append(OpenAIUserMessageParam(content=input)) - return messages - - -async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage: - """ - Convert an OpenAI Chat Completion choice into an OpenAI Response output message. - """ - output_content = "" - if isinstance(choice.message.content, str): - output_content = choice.message.content - elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam): - output_content = choice.message.content.text - else: - raise ValueError( - f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}" - ) - - return OpenAIResponseMessage( - id=f"msg_{uuid.uuid4()}", - content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)], - status="completed", - role="assistant", - ) - - -async def _convert_response_text_to_chat_response_format(text: OpenAIResponseText) -> OpenAIResponseFormatParam: - """ - Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format. 
- """ - if not text.format or text.format["type"] == "text": - return OpenAIResponseFormatText(type="text") - if text.format["type"] == "json_object": - return OpenAIResponseFormatJSONObject() - if text.format["type"] == "json_schema": - return OpenAIResponseFormatJSONSchema( - json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"]) - ) - raise ValueError(f"Unsupported text format: {text.format}") - - -async def _get_message_type_by_role(role: str): - role_to_type = { - "user": OpenAIUserMessageParam, - "system": OpenAISystemMessageParam, - "assistant": OpenAIAssistantMessageParam, - "developer": OpenAIDeveloperMessageParam, - } - return role_to_type.get(role) - - -class OpenAIResponsePreviousResponseWithInputItems(BaseModel): - input_items: ListOpenAIResponseInputItem - response: OpenAIResponseObject - - -class ChatCompletionContext(BaseModel): - model: str - messages: list[OpenAIMessageParam] - response_tools: list[OpenAIResponseInputTool] | None = None - chat_tools: list[ChatCompletionToolParam] | None = None - mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] - temperature: float | None - response_format: OpenAIResponseFormatParam - - -class OpenAIResponsesImpl: - def __init__( - self, - inference_api: Inference, - tool_groups_api: ToolGroups, - tool_runtime_api: ToolRuntime, - responses_store: ResponsesStore, - vector_io_api: VectorIO, # VectorIO - ): - self.inference_api = inference_api - self.tool_groups_api = tool_groups_api - self.tool_runtime_api = tool_runtime_api - self.responses_store = responses_store - self.vector_io_api = vector_io_api - - async def _prepend_previous_response( - self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None - ): - if previous_response_id: - previous_response_with_input = await self.responses_store.get_response_object(previous_response_id) - - # previous response input items - new_input_items = previous_response_with_input.input - - # previous response output items - new_input_items.extend(previous_response_with_input.output) - - # new input items from the current request - if isinstance(input, str): - new_input_items.append(OpenAIResponseMessage(content=input, role="user")) - else: - new_input_items.extend(input) - - input = new_input_items - - return input - - async def _prepend_instructions(self, messages, instructions): - if instructions: - messages.insert(0, OpenAISystemMessageParam(content=instructions)) - - async def get_openai_response( - self, - response_id: str, - ) -> OpenAIResponseObject: - response_with_input = await self.responses_store.get_response_object(response_id) - return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"}) - - async def list_openai_responses( - self, - after: str | None = None, - limit: int | None = 50, - model: str | None = None, - order: Order | None = Order.desc, - ) -> ListOpenAIResponseObject: - return await self.responses_store.list_responses(after, limit, model, order) - - async def list_openai_response_input_items( - self, - response_id: str, - after: str | None = None, - before: str | None = None, - include: list[str] | None = None, - limit: int | None = 20, - order: Order | None = Order.desc, - ) -> ListOpenAIResponseInputItem: - """List input items for a given OpenAI response. - - :param response_id: The ID of the response to retrieve input items for. - :param after: An item ID to list items after, used for pagination. - :param before: An item ID to list items before, used for pagination. 
- :param include: Additional fields to include in the response. - :param limit: A limit on the number of objects to be returned. - :param order: The order to return the input items in. - :returns: An ListOpenAIResponseInputItem. - """ - return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order) - - async def _store_response( - self, - response: OpenAIResponseObject, - input: str | list[OpenAIResponseInput], - ) -> None: - new_input_id = f"msg_{uuid.uuid4()}" - if isinstance(input, str): - # synthesize a message from the input string - input_content = OpenAIResponseInputMessageContentText(text=input) - input_content_item = OpenAIResponseMessage( - role="user", - content=[input_content], - id=new_input_id, - ) - input_items_data = [input_content_item] - else: - # we already have a list of messages - input_items_data = [] - for input_item in input: - if isinstance(input_item, OpenAIResponseMessage): - # These may or may not already have an id, so dump to dict, check for id, and add if missing - input_item_dict = input_item.model_dump() - if "id" not in input_item_dict: - input_item_dict["id"] = new_input_id - input_items_data.append(OpenAIResponseMessage(**input_item_dict)) - else: - input_items_data.append(input_item) - - await self.responses_store.store_response_object( - response_object=response, - input=input_items_data, - ) - - async def create_openai_response( - self, - input: str | list[OpenAIResponseInput], - model: str, - instructions: str | None = None, - previous_response_id: str | None = None, - store: bool | None = True, - stream: bool | None = False, - temperature: float | None = None, - text: OpenAIResponseText | None = None, - tools: list[OpenAIResponseInputTool] | None = None, - max_infer_iters: int | None = 10, - ): - stream = bool(stream) - text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text - - stream_gen = self._create_streaming_response( - input=input, - model=model, - instructions=instructions, - previous_response_id=previous_response_id, - store=store, - temperature=temperature, - text=text, - tools=tools, - max_infer_iters=max_infer_iters, - ) - - if stream: - return stream_gen - else: - response = None - async for stream_chunk in stream_gen: - if stream_chunk.type == "response.completed": - if response is not None: - raise ValueError("The response stream completed multiple times! Earlier response: {response}") - response = stream_chunk.response - # don't leave the generator half complete! 
- - if response is None: - raise ValueError("The response stream never completed") - return response - - async def _create_streaming_response( - self, - input: str | list[OpenAIResponseInput], - model: str, - instructions: str | None = None, - previous_response_id: str | None = None, - store: bool | None = True, - temperature: float | None = None, - text: OpenAIResponseText | None = None, - tools: list[OpenAIResponseInputTool] | None = None, - max_infer_iters: int | None = 10, - ) -> AsyncIterator[OpenAIResponseObjectStream]: - output_messages: list[OpenAIResponseOutput] = [] - - # Input preprocessing - input = await self._prepend_previous_response(input, previous_response_id) - messages = await _convert_response_input_to_chat_messages(input) - await self._prepend_instructions(messages, instructions) - - # Structured outputs - response_format = await _convert_response_text_to_chat_response_format(text) - - # Tool setup, TODO: refactor this slightly since this can also yield events - chat_tools, mcp_tool_to_server, mcp_list_message = ( - await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None) - ) - if mcp_list_message: - output_messages.append(mcp_list_message) - - ctx = ChatCompletionContext( - model=model, - messages=messages, - response_tools=tools, - chat_tools=chat_tools, - mcp_tool_to_server=mcp_tool_to_server, - temperature=temperature, - response_format=response_format, - ) - - # Create initial response and emit response.created immediately - response_id = f"resp-{uuid.uuid4()}" - created_at = int(time.time()) - - initial_response = OpenAIResponseObject( - created_at=created_at, - id=response_id, - model=model, - object="response", - status="in_progress", - output=output_messages.copy(), - text=text, - ) - - yield OpenAIResponseObjectStreamResponseCreated(response=initial_response) - - n_iter = 0 - messages = ctx.messages.copy() - - while True: - completion_result = await self.inference_api.openai_chat_completion( - model=ctx.model, - messages=messages, - tools=ctx.chat_tools, - stream=True, - temperature=ctx.temperature, - response_format=ctx.response_format, - ) - - # Process streaming chunks and build complete response - chat_response_id = "" - chat_response_content = [] - chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {} - chunk_created = 0 - chunk_model = "" - chunk_finish_reason = "" - sequence_number = 0 - - # Create a placeholder message item for delta events - message_item_id = f"msg_{uuid.uuid4()}" - - async for chunk in completion_result: - chat_response_id = chunk.id - chunk_created = chunk.created - chunk_model = chunk.model - for chunk_choice in chunk.choices: - # Emit incremental text content as delta events - if chunk_choice.delta.content: - sequence_number += 1 - yield OpenAIResponseObjectStreamResponseOutputTextDelta( - content_index=0, - delta=chunk_choice.delta.content, - item_id=message_item_id, - output_index=0, - sequence_number=sequence_number, - ) - - # Collect content for final response - chat_response_content.append(chunk_choice.delta.content or "") - if chunk_choice.finish_reason: - chunk_finish_reason = chunk_choice.finish_reason - - # Aggregate tool call arguments across chunks - if chunk_choice.delta.tool_calls: - for tool_call in chunk_choice.delta.tool_calls: - response_tool_call = chat_response_tool_calls.get(tool_call.index, None) - if response_tool_call: - # Don't attempt to concatenate arguments if we don't have any new arguments - if tool_call.function.arguments: - #
Guard against an initial None argument before we concatenate - response_tool_call.function.arguments = ( - response_tool_call.function.arguments or "" - ) + tool_call.function.arguments - else: - tool_call_dict: dict[str, Any] = tool_call.model_dump() - tool_call_dict.pop("type", None) - response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict) - chat_response_tool_calls[tool_call.index] = response_tool_call - - # Convert collected chunks to complete response - if chat_response_tool_calls: - tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())] - else: - tool_calls = None - assistant_message = OpenAIAssistantMessageParam( - content="".join(chat_response_content), - tool_calls=tool_calls, - ) - current_response = OpenAIChatCompletion( - id=chat_response_id, - choices=[ - OpenAIChoice( - message=assistant_message, - finish_reason=chunk_finish_reason, - index=0, - ) - ], - created=chunk_created, - model=chunk_model, - ) - - function_tool_calls = [] - non_function_tool_calls = [] - - next_turn_messages = messages.copy() - for choice in current_response.choices: - next_turn_messages.append(choice.message) - - if choice.message.tool_calls and tools: - for tool_call in choice.message.tool_calls: - if _is_function_tool_call(tool_call, tools): - function_tool_calls.append(tool_call) - else: - non_function_tool_calls.append(tool_call) - else: - output_messages.append(await _convert_chat_choice_to_response_message(choice)) - - # execute non-function tool calls - for tool_call in non_function_tool_calls: - tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx) - if tool_call_log: - output_messages.append(tool_call_log) - if tool_response_message: - next_turn_messages.append(tool_response_message) - - for tool_call in function_tool_calls: - output_messages.append( - OpenAIResponseOutputMessageFunctionToolCall( - arguments=tool_call.function.arguments or "", - call_id=tool_call.id, - name=tool_call.function.name or "", - id=f"fc_{uuid.uuid4()}", - status="completed", - ) - ) - - if not function_tool_calls and not non_function_tool_calls: - break - - if function_tool_calls: - logger.info("Exiting inference loop since there is a function (client-side) tool call") - break - - n_iter += 1 - if n_iter >= max_infer_iters: - logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}") - break - - messages = next_turn_messages - - # Create final response - final_response = OpenAIResponseObject( - created_at=created_at, - id=response_id, - model=model, - object="response", - status="completed", - text=text, - output=output_messages, - ) - - # Emit response.completed - yield OpenAIResponseObjectStreamResponseCompleted(response=final_response) - - if store: - await self._store_response( - response=final_response, - input=input, - ) - - async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject: - return await self.responses_store.delete_response_object(response_id) - - async def _convert_response_tools_to_chat_tools( - self, tools: list[OpenAIResponseInputTool] - ) -> tuple[ - list[ChatCompletionToolParam], - dict[str, OpenAIResponseInputToolMCP], - OpenAIResponseOutput | None, - ]: - from llama_stack.apis.agents.openai_responses import ( - MCPListToolsTool, - ) - from llama_stack.apis.tools import Tool - - mcp_tool_to_server = {} - - def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam: - tool_def = ToolDefinition( - tool_name=tool_name, - 
description=tool.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in tool.parameters - }, - ) - return convert_tooldef_to_openai_tool(tool_def) - - mcp_list_message = None - chat_tools: list[ChatCompletionToolParam] = [] - for input_tool in tools: - # TODO: Handle other tool types - if input_tool.type == "function": - chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) - elif input_tool.type in WebSearchToolTypes: - tool_name = "web_search" - tool = await self.tool_groups_api.get_tool(tool_name) - if not tool: - raise ValueError(f"Tool {tool_name} not found") - chat_tools.append(make_openai_tool(tool_name, tool)) - elif input_tool.type == "file_search": - tool_name = "knowledge_search" - tool = await self.tool_groups_api.get_tool(tool_name) - if not tool: - raise ValueError(f"Tool {tool_name} not found") - chat_tools.append(make_openai_tool(tool_name, tool)) - elif input_tool.type == "mcp": - from llama_stack.providers.utils.tools.mcp import list_mcp_tools - - always_allowed = None - never_allowed = None - if input_tool.allowed_tools: - if isinstance(input_tool.allowed_tools, list): - always_allowed = input_tool.allowed_tools - elif isinstance(input_tool.allowed_tools, AllowedToolsFilter): - always_allowed = input_tool.allowed_tools.always - never_allowed = input_tool.allowed_tools.never - - tool_defs = await list_mcp_tools( - endpoint=input_tool.server_url, - headers=input_tool.headers or {}, - ) - - mcp_list_message = OpenAIResponseOutputMessageMCPListTools( - id=f"mcp_list_{uuid.uuid4()}", - status="completed", - server_label=input_tool.server_label, - tools=[], - ) - for t in tool_defs.data: - if never_allowed and t.name in never_allowed: - continue - if not always_allowed or t.name in always_allowed: - chat_tools.append(make_openai_tool(t.name, t)) - if t.name in mcp_tool_to_server: - raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}") - mcp_tool_to_server[t.name] = input_tool - mcp_list_message.tools.append( - MCPListToolsTool( - name=t.name, - description=t.description, - input_schema={ - "type": "object", - "properties": { - p.name: { - "type": p.parameter_type, - "description": p.description, - } - for p in t.parameters - }, - "required": [p.name for p in t.parameters if p.required], - }, - ) - ) - else: - raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}") - return chat_tools, mcp_tool_to_server, mcp_list_message - - async def _execute_knowledge_search_via_vector_store( - self, - query: str, - response_file_search_tool: OpenAIResponseInputToolFileSearch, - ) -> ToolInvocationResult: - """Execute knowledge search using vector_stores.search API with filters support.""" - search_results = [] - - # Create search tasks for all vector stores - async def search_single_store(vector_store_id): - try: - search_response = await self.vector_io_api.openai_search_vector_store( - vector_store_id=vector_store_id, - query=query, - filters=response_file_search_tool.filters, - max_num_results=response_file_search_tool.max_num_results, - ranking_options=response_file_search_tool.ranking_options, - rewrite_query=False, - ) - return search_response.data - except Exception as e: - logger.warning(f"Failed to search vector store {vector_store_id}: {e}") - return [] - - # Run all searches in parallel using gather - search_tasks = 
[search_single_store(vid) for vid in response_file_search_tool.vector_store_ids] - all_results = await asyncio.gather(*search_tasks) - - # Flatten results - for results in all_results: - search_results.extend(results) - - # Convert search results to tool result format matching memory.py - # Format the results as interleaved content similar to memory.py - content_items = [] - content_items.append( - TextContentItem( - text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n" - ) - ) - - for i, result_item in enumerate(search_results): - chunk_text = result_item.content[0].text if result_item.content else "" - metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}" - if result_item.attributes: - metadata_text += f", attributes: {result_item.attributes}" - text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n" - content_items.append(TextContentItem(text=text_content)) - - content_items.append(TextContentItem(text="END of knowledge_search tool results.\n")) - content_items.append( - TextContentItem( - text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n', - ) - ) - - return ToolInvocationResult( - content=content_items, - metadata={ - "document_ids": [r.file_id for r in search_results], - "chunks": [r.content[0].text if r.content else "" for r in search_results], - "scores": [r.score for r in search_results], - }, - ) - - async def _execute_tool_call( - self, - tool_call: OpenAIChatCompletionToolCall, - ctx: ChatCompletionContext, - ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]: - from llama_stack.providers.utils.inference.prompt_adapter import ( - interleaved_content_as_str, - ) - - tool_call_id = tool_call.id - function = tool_call.function - tool_kwargs = json.loads(function.arguments) if function.arguments else {} - - if not function or not tool_call_id or not function.name: - return None, None - - error_exc = None - result = None - try: - if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server: - from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool - - mcp_tool = ctx.mcp_tool_to_server[function.name] - result = await invoke_mcp_tool( - endpoint=mcp_tool.server_url, - headers=mcp_tool.headers or {}, - tool_name=function.name, - kwargs=tool_kwargs, - ) - elif function.name == "knowledge_search": - response_file_search_tool = next( - (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None - ) - if response_file_search_tool: - # Use vector_stores.search API instead of knowledge_search tool - # to support filters and ranking_options - query = tool_kwargs.get("query", "") - result = await self._execute_knowledge_search_via_vector_store( - query=query, - response_file_search_tool=response_file_search_tool, - ) - else: - result = await self.tool_runtime_api.invoke_tool( - tool_name=function.name, - kwargs=tool_kwargs, - ) - except Exception as e: - error_exc = e - - if function.name in ctx.mcp_tool_to_server: - from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall - - message = OpenAIResponseOutputMessageMCPCall( - id=tool_call_id, - arguments=function.arguments, - name=function.name, - server_label=ctx.mcp_tool_to_server[function.name].server_label, - ) - if error_exc: - message.error = str(error_exc) - elif (result.error_code and result.error_code > 0) or result.error_message: - message.error = f"Error 
(code {result.error_code}): {result.error_message}" - elif result.content: - message.output = interleaved_content_as_str(result.content) - else: - if function.name == "web_search": - message = OpenAIResponseOutputMessageWebSearchToolCall( - id=tool_call_id, - status="completed", - ) - if error_exc or (result.error_code and result.error_code > 0) or result.error_message: - message.status = "failed" - elif function.name == "knowledge_search": - message = OpenAIResponseOutputMessageFileSearchToolCall( - id=tool_call_id, - queries=[tool_kwargs.get("query", "")], - status="completed", - ) - if "document_ids" in result.metadata: - message.results = [] - for i, doc_id in enumerate(result.metadata["document_ids"]): - text = result.metadata["chunks"][i] if "chunks" in result.metadata else None - score = result.metadata["scores"][i] if "scores" in result.metadata else None - message.results.append( - { - "file_id": doc_id, - "filename": doc_id, - "text": text, - "score": score, - } - ) - if error_exc or (result.error_code and result.error_code > 0) or result.error_message: - message.status = "failed" - else: - raise ValueError(f"Unknown tool {function.name} called") - - input_message = None - if result and result.content: - if isinstance(result.content, str): - content = result.content - elif isinstance(result.content, list): - from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem - - content = [] - for item in result.content: - if isinstance(item, TextContentItem): - part = OpenAIChatCompletionContentPartTextParam(text=item.text) - elif isinstance(item, ImageContentItem): - if item.image.data: - url = f"data:image;base64,{item.image.data}" - else: - url = item.image.url - part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) - else: - raise ValueError(f"Unknown result content type: {type(item)}") - content.append(part) - else: - raise ValueError(f"Unknown result content type: {type(result.content)}") - input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) - else: - text = str(error_exc) - input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id) - - return message, input_message - - -def _is_function_tool_call( - tool_call: OpenAIChatCompletionToolCall, - tools: list[OpenAIResponseInputTool], -) -> bool: - if not tool_call.function: - return False - for t in tools: - if t.type == "function" and t.name == tool_call.function.name: - return True - return False diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 7a8d99b78..0b234d96c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -191,7 +191,11 @@ class AgentPersistence: sessions = [] for value in values: try: - session_info = Session(**json.loads(value)) + data = json.loads(value) + if "turn_id" in data: + continue + + session_info = Session(**data) sessions.append(session_info) except Exception as e: log.error(f"Error parsing session info: {e}") diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/__init__.py b/llama_stack/providers/inline/agents/meta_reference/responses/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py new file mode 100644 index 000000000..e528a4005 --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import time +import uuid +from collections.abc import AsyncIterator + +from pydantic import BaseModel + +from llama_stack.apis.agents import Order +from llama_stack.apis.agents.openai_responses import ( + ListOpenAIResponseInputItem, + ListOpenAIResponseObject, + OpenAIDeleteResponseObject, + OpenAIResponseInput, + OpenAIResponseInputMessageContentText, + OpenAIResponseInputTool, + OpenAIResponseMessage, + OpenAIResponseObject, + OpenAIResponseObjectStream, + OpenAIResponseText, + OpenAIResponseTextFormat, +) +from llama_stack.apis.inference import ( + Inference, + OpenAISystemMessageParam, +) +from llama_stack.apis.tools import ToolGroups, ToolRuntime +from llama_stack.apis.vector_io import VectorIO +from llama_stack.log import get_logger +from llama_stack.providers.utils.responses.responses_store import ResponsesStore + +from .streaming import StreamingResponseOrchestrator +from .tool_executor import ToolExecutor +from .types import ChatCompletionContext +from .utils import ( + convert_response_input_to_chat_messages, + convert_response_text_to_chat_response_format, +) + +logger = get_logger(name=__name__, category="responses") + + +class OpenAIResponsePreviousResponseWithInputItems(BaseModel): + input_items: ListOpenAIResponseInputItem + response: OpenAIResponseObject + + +class OpenAIResponsesImpl: + def __init__( + self, + inference_api: Inference, + tool_groups_api: ToolGroups, + tool_runtime_api: ToolRuntime, + responses_store: ResponsesStore, + vector_io_api: VectorIO, # VectorIO + ): + self.inference_api = inference_api + self.tool_groups_api = tool_groups_api + self.tool_runtime_api = tool_runtime_api + self.responses_store = responses_store + self.vector_io_api = vector_io_api + self.tool_executor = ToolExecutor( + tool_groups_api=tool_groups_api, + tool_runtime_api=tool_runtime_api, + vector_io_api=vector_io_api, + ) + + async def _prepend_previous_response( + self, + input: str | list[OpenAIResponseInput], + previous_response_id: str | None = None, + ): + if previous_response_id: + previous_response_with_input = await self.responses_store.get_response_object(previous_response_id) + + # previous response input items + new_input_items = previous_response_with_input.input + + # previous response output items + new_input_items.extend(previous_response_with_input.output) + + # new input items from the current request + if isinstance(input, str): + new_input_items.append(OpenAIResponseMessage(content=input, role="user")) + else: + new_input_items.extend(input) + + input = new_input_items + + return input + + async def _prepend_instructions(self, messages, instructions): + if instructions: + messages.insert(0, OpenAISystemMessageParam(content=instructions)) + + async def get_openai_response( + self, + response_id: str, + ) -> OpenAIResponseObject: + response_with_input = await 
self.responses_store.get_response_object(response_id) + return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"}) + + async def list_openai_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + return await self.responses_store.list_responses(after, limit, model, order) + + async def list_openai_response_input_items( + self, + response_id: str, + after: str | None = None, + before: str | None = None, + include: list[str] | None = None, + limit: int | None = 20, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseInputItem: + """List input items for a given OpenAI response. + + :param response_id: The ID of the response to retrieve input items for. + :param after: An item ID to list items after, used for pagination. + :param before: An item ID to list items before, used for pagination. + :param include: Additional fields to include in the response. + :param limit: A limit on the number of objects to be returned. + :param order: The order to return the input items in. + :returns: An ListOpenAIResponseInputItem. + """ + return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order) + + async def _store_response( + self, + response: OpenAIResponseObject, + input: str | list[OpenAIResponseInput], + ) -> None: + new_input_id = f"msg_{uuid.uuid4()}" + if isinstance(input, str): + # synthesize a message from the input string + input_content = OpenAIResponseInputMessageContentText(text=input) + input_content_item = OpenAIResponseMessage( + role="user", + content=[input_content], + id=new_input_id, + ) + input_items_data = [input_content_item] + else: + # we already have a list of messages + input_items_data = [] + for input_item in input: + if isinstance(input_item, OpenAIResponseMessage): + # These may or may not already have an id, so dump to dict, check for id, and add if missing + input_item_dict = input_item.model_dump() + if "id" not in input_item_dict: + input_item_dict["id"] = new_input_id + input_items_data.append(OpenAIResponseMessage(**input_item_dict)) + else: + input_items_data.append(input_item) + + await self.responses_store.store_response_object( + response_object=response, + input=input_items_data, + ) + + async def create_openai_response( + self, + input: str | list[OpenAIResponseInput], + model: str, + instructions: str | None = None, + previous_response_id: str | None = None, + store: bool | None = True, + stream: bool | None = False, + temperature: float | None = None, + text: OpenAIResponseText | None = None, + tools: list[OpenAIResponseInputTool] | None = None, + include: list[str] | None = None, + max_infer_iters: int | None = 10, + ): + stream = bool(stream) + text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text + + stream_gen = self._create_streaming_response( + input=input, + model=model, + instructions=instructions, + previous_response_id=previous_response_id, + store=store, + temperature=temperature, + text=text, + tools=tools, + max_infer_iters=max_infer_iters, + ) + + if stream: + return stream_gen + else: + response = None + async for stream_chunk in stream_gen: + if stream_chunk.type == "response.completed": + if response is not None: + raise ValueError("The response stream completed multiple times! 
Earlier response: {response}") + response = stream_chunk.response + # don't leave the generator half complete! + + if response is None: + raise ValueError("The response stream never completed") + return response + + async def _create_streaming_response( + self, + input: str | list[OpenAIResponseInput], + model: str, + instructions: str | None = None, + previous_response_id: str | None = None, + store: bool | None = True, + temperature: float | None = None, + text: OpenAIResponseText | None = None, + tools: list[OpenAIResponseInputTool] | None = None, + max_infer_iters: int | None = 10, + ) -> AsyncIterator[OpenAIResponseObjectStream]: + # Input preprocessing + input = await self._prepend_previous_response(input, previous_response_id) + messages = await convert_response_input_to_chat_messages(input) + await self._prepend_instructions(messages, instructions) + + # Structured outputs + response_format = await convert_response_text_to_chat_response_format(text) + + ctx = ChatCompletionContext( + model=model, + messages=messages, + response_tools=tools, + temperature=temperature, + response_format=response_format, + ) + + # Create orchestrator and delegate streaming logic + response_id = f"resp-{uuid.uuid4()}" + created_at = int(time.time()) + + orchestrator = StreamingResponseOrchestrator( + inference_api=self.inference_api, + ctx=ctx, + response_id=response_id, + created_at=created_at, + text=text, + max_infer_iters=max_infer_iters, + tool_executor=self.tool_executor, + ) + + # Stream the response + final_response = None + async for stream_chunk in orchestrator.create_response(): + if stream_chunk.type == "response.completed": + final_response = stream_chunk.response + yield stream_chunk + + # Store the response if requested + if store and final_response: + await self._store_response( + response=final_response, + input=input, + ) + + async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject: + return await self.responses_store.delete_response_object(response_id) diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py new file mode 100644 index 000000000..0879e978a --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -0,0 +1,634 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import uuid +from collections.abc import AsyncIterator +from typing import Any + +from llama_stack.apis.agents.openai_responses import ( + AllowedToolsFilter, + MCPListToolsTool, + OpenAIResponseContentPartOutputText, + OpenAIResponseInputTool, + OpenAIResponseInputToolMCP, + OpenAIResponseObject, + OpenAIResponseObjectStream, + OpenAIResponseObjectStreamResponseCompleted, + OpenAIResponseObjectStreamResponseContentPartAdded, + OpenAIResponseObjectStreamResponseContentPartDone, + OpenAIResponseObjectStreamResponseCreated, + OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta, + OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone, + OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta, + OpenAIResponseObjectStreamResponseMcpCallArgumentsDone, + OpenAIResponseObjectStreamResponseMcpListToolsCompleted, + OpenAIResponseObjectStreamResponseMcpListToolsInProgress, + OpenAIResponseObjectStreamResponseOutputItemAdded, + OpenAIResponseObjectStreamResponseOutputItemDone, + OpenAIResponseObjectStreamResponseOutputTextDelta, + OpenAIResponseOutput, + OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseOutputMessageMCPListTools, + OpenAIResponseText, + WebSearchToolTypes, +) +from llama_stack.apis.inference import ( + Inference, + OpenAIAssistantMessageParam, + OpenAIChatCompletion, + OpenAIChatCompletionToolCall, + OpenAIChoice, +) +from llama_stack.log import get_logger + +from .types import ChatCompletionContext, ChatCompletionResult +from .utils import convert_chat_choice_to_response_message, is_function_tool_call + +logger = get_logger(name=__name__, category="responses") + + +class StreamingResponseOrchestrator: + def __init__( + self, + inference_api: Inference, + ctx: ChatCompletionContext, + response_id: str, + created_at: int, + text: OpenAIResponseText, + max_infer_iters: int, + tool_executor, # Will be the tool execution logic from the main class + ): + self.inference_api = inference_api + self.ctx = ctx + self.response_id = response_id + self.created_at = created_at + self.text = text + self.max_infer_iters = max_infer_iters + self.tool_executor = tool_executor + self.sequence_number = 0 + # Store MCP tool mapping that gets built during tool processing + self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {} + + async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]: + # Initialize output messages + output_messages: list[OpenAIResponseOutput] = [] + # Create initial response and emit response.created immediately + initial_response = OpenAIResponseObject( + created_at=self.created_at, + id=self.response_id, + model=self.ctx.model, + object="response", + status="in_progress", + output=output_messages.copy(), + text=self.text, + ) + + yield OpenAIResponseObjectStreamResponseCreated(response=initial_response) + + # Process all tools (including MCP tools) and emit streaming events + if self.ctx.response_tools: + async for stream_event in self._process_tools(self.ctx.response_tools, output_messages): + yield stream_event + + n_iter = 0 + messages = self.ctx.messages.copy() + + while True: + completion_result = await self.inference_api.openai_chat_completion( + model=self.ctx.model, + messages=messages, + tools=self.ctx.chat_tools, + stream=True, + temperature=self.ctx.temperature, + response_format=self.ctx.response_format, + ) + + # Process streaming chunks and build complete response + completion_result_data = None + async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages): + if 
isinstance(stream_event_or_result, ChatCompletionResult): + completion_result_data = stream_event_or_result + else: + yield stream_event_or_result + if not completion_result_data: + raise ValueError("Streaming chunk processor failed to return completion data") + current_response = self._build_chat_completion(completion_result_data) + + function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls( + current_response, messages + ) + + # Handle choices with no tool calls + for choice in current_response.choices: + if not (choice.message.tool_calls and self.ctx.response_tools): + output_messages.append(await convert_chat_choice_to_response_message(choice)) + + # Execute tool calls and coordinate results + async for stream_event in self._coordinate_tool_execution( + function_tool_calls, + non_function_tool_calls, + completion_result_data, + output_messages, + next_turn_messages, + ): + yield stream_event + + if not function_tool_calls and not non_function_tool_calls: + break + + if function_tool_calls: + logger.info("Exiting inference loop since there is a function (client-side) tool call") + break + + n_iter += 1 + if n_iter >= self.max_infer_iters: + logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}") + break + + messages = next_turn_messages + + # Create final response + final_response = OpenAIResponseObject( + created_at=self.created_at, + id=self.response_id, + model=self.ctx.model, + object="response", + status="completed", + text=self.text, + output=output_messages, + ) + + # Emit response.completed + yield OpenAIResponseObjectStreamResponseCompleted(response=final_response) + + def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]: + """Separate tool calls into function and non-function categories.""" + function_tool_calls = [] + non_function_tool_calls = [] + next_turn_messages = messages.copy() + + for choice in current_response.choices: + next_turn_messages.append(choice.message) + + if choice.message.tool_calls and self.ctx.response_tools: + for tool_call in choice.message.tool_calls: + if is_function_tool_call(tool_call, self.ctx.response_tools): + function_tool_calls.append(tool_call) + else: + non_function_tool_calls.append(tool_call) + + return function_tool_calls, non_function_tool_calls, next_turn_messages + + async def _process_streaming_chunks( + self, completion_result, output_messages: list[OpenAIResponseOutput] + ) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]: + """Process streaming chunks and emit events, returning completion data.""" + # Initialize result tracking + chat_response_id = "" + chat_response_content = [] + chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {} + chunk_created = 0 + chunk_model = "" + chunk_finish_reason = "" + + # Create a placeholder message item for delta events + message_item_id = f"msg_{uuid.uuid4()}" + # Track tool call items for streaming events + tool_call_item_ids: dict[int, str] = {} + # Track content parts for streaming events + content_part_emitted = False + + async for chunk in completion_result: + chat_response_id = chunk.id + chunk_created = chunk.created + chunk_model = chunk.model + for chunk_choice in chunk.choices: + # Emit incremental text content as delta events + if chunk_choice.delta.content: + # Emit content_part.added event for first text chunk + if not content_part_emitted: + content_part_emitted = True + self.sequence_number += 1 + yield 
OpenAIResponseObjectStreamResponseContentPartAdded( + response_id=self.response_id, + item_id=message_item_id, + part=OpenAIResponseContentPartOutputText( + text="", # Will be filled incrementally via text deltas + ), + sequence_number=self.sequence_number, + ) + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputTextDelta( + content_index=0, + delta=chunk_choice.delta.content, + item_id=message_item_id, + output_index=0, + sequence_number=self.sequence_number, + ) + + # Collect content for final response + chat_response_content.append(chunk_choice.delta.content or "") + if chunk_choice.finish_reason: + chunk_finish_reason = chunk_choice.finish_reason + + # Aggregate tool call arguments across chunks + if chunk_choice.delta.tool_calls: + for tool_call in chunk_choice.delta.tool_calls: + response_tool_call = chat_response_tool_calls.get(tool_call.index, None) + # Create new tool call entry if this is the first chunk for this index + is_new_tool_call = response_tool_call is None + if is_new_tool_call: + tool_call_dict: dict[str, Any] = tool_call.model_dump() + tool_call_dict.pop("type", None) + response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict) + chat_response_tool_calls[tool_call.index] = response_tool_call + + # Create item ID for this tool call for streaming events + tool_call_item_id = f"fc_{uuid.uuid4()}" + tool_call_item_ids[tool_call.index] = tool_call_item_id + + # Emit output_item.added event for the new function call + self.sequence_number += 1 + function_call_item = OpenAIResponseOutputMessageFunctionToolCall( + arguments="", # Will be filled incrementally via delta events + call_id=tool_call.id or "", + name=tool_call.function.name if tool_call.function else "", + id=tool_call_item_id, + status="in_progress", + ) + yield OpenAIResponseObjectStreamResponseOutputItemAdded( + response_id=self.response_id, + item=function_call_item, + output_index=len(output_messages), + sequence_number=self.sequence_number, + ) + + # Stream tool call arguments as they arrive (differentiate between MCP and function calls) + if tool_call.function and tool_call.function.arguments: + tool_call_item_id = tool_call_item_ids[tool_call.index] + self.sequence_number += 1 + + # Check if this is an MCP tool call + is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server + if is_mcp_tool: + # Emit MCP-specific argument delta event + yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta( + delta=tool_call.function.arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=self.sequence_number, + ) + else: + # Emit function call argument delta event + yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta( + delta=tool_call.function.arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=self.sequence_number, + ) + + # Accumulate arguments for final response (only for subsequent chunks) + if not is_new_tool_call: + response_tool_call.function.arguments = ( + response_tool_call.function.arguments or "" + ) + tool_call.function.arguments + + # Emit arguments.done events for completed tool calls (differentiate between MCP and function calls) + for tool_call_index in sorted(chat_response_tool_calls.keys()): + tool_call_item_id = tool_call_item_ids[tool_call_index] + final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or "" + tool_call_name = chat_response_tool_calls[tool_call_index].function.name + + # Check if this is an 
MCP tool call + is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server + self.sequence_number += 1 + done_event_cls = ( + OpenAIResponseObjectStreamResponseMcpCallArgumentsDone + if is_mcp_tool + else OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone + ) + yield done_event_cls( + arguments=final_arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=self.sequence_number, + ) + + # Emit content_part.done event if text content was streamed (before content gets cleared) + if content_part_emitted: + final_text = "".join(chat_response_content) + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseContentPartDone( + response_id=self.response_id, + item_id=message_item_id, + part=OpenAIResponseContentPartOutputText( + text=final_text, + ), + sequence_number=self.sequence_number, + ) + + # Clear content when there are tool calls (OpenAI spec behavior) + if chat_response_tool_calls: + chat_response_content = [] + + yield ChatCompletionResult( + response_id=chat_response_id, + content=chat_response_content, + tool_calls=chat_response_tool_calls, + created=chunk_created, + model=chunk_model, + finish_reason=chunk_finish_reason, + message_item_id=message_item_id, + tool_call_item_ids=tool_call_item_ids, + content_part_emitted=content_part_emitted, + ) + + def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion: + """Build OpenAIChatCompletion from ChatCompletionResult.""" + # Convert collected chunks to complete response + if result.tool_calls: + tool_calls = [result.tool_calls[i] for i in sorted(result.tool_calls.keys())] + else: + tool_calls = None + + assistant_message = OpenAIAssistantMessageParam( + content=result.content_text, + tool_calls=tool_calls, + ) + return OpenAIChatCompletion( + id=result.response_id, + choices=[ + OpenAIChoice( + message=assistant_message, + finish_reason=result.finish_reason, + index=0, + ) + ], + created=result.created, + model=result.model, + ) + + async def _coordinate_tool_execution( + self, + function_tool_calls: list, + non_function_tool_calls: list, + completion_result_data: ChatCompletionResult, + output_messages: list[OpenAIResponseOutput], + next_turn_messages: list, + ) -> AsyncIterator[OpenAIResponseObjectStream]: + """Coordinate execution of both function and non-function tool calls.""" + # Execute non-function tool calls + for tool_call in non_function_tool_calls: + # Find the item_id for this tool call + matching_item_id = None + for index, item_id in completion_result_data.tool_call_item_ids.items(): + response_tool_call = completion_result_data.tool_calls.get(index) + if response_tool_call and response_tool_call.id == tool_call.id: + matching_item_id = item_id + break + + # Use a fallback item_id if not found + if not matching_item_id: + matching_item_id = f"tc_{uuid.uuid4()}" + + # Execute tool call with streaming + tool_call_log = None + tool_response_message = None + async for result in self.tool_executor.execute_tool_call( + tool_call, + self.ctx, + self.sequence_number, + len(output_messages), + matching_item_id, + self.mcp_tool_to_server, + ): + if result.stream_event: + # Forward streaming events + self.sequence_number = result.sequence_number + yield result.stream_event + + if result.final_output_message is not None: + tool_call_log = result.final_output_message + tool_response_message = result.final_input_message + self.sequence_number = result.sequence_number + + if tool_call_log: + output_messages.append(tool_call_log) + + # 
Emit output_item.done event for completed non-function tool call + if matching_item_id: + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemDone( + response_id=self.response_id, + item=tool_call_log, + output_index=len(output_messages) - 1, + sequence_number=self.sequence_number, + ) + + if tool_response_message: + next_turn_messages.append(tool_response_message) + + # Execute function tool calls (client-side) + for tool_call in function_tool_calls: + # Find the item_id for this tool call from our tracking dictionary + matching_item_id = None + for index, item_id in completion_result_data.tool_call_item_ids.items(): + response_tool_call = completion_result_data.tool_calls.get(index) + if response_tool_call and response_tool_call.id == tool_call.id: + matching_item_id = item_id + break + + # Use existing item_id or create new one if not found + final_item_id = matching_item_id or f"fc_{uuid.uuid4()}" + + function_call_item = OpenAIResponseOutputMessageFunctionToolCall( + arguments=tool_call.function.arguments or "", + call_id=tool_call.id, + name=tool_call.function.name or "", + id=final_item_id, + status="completed", + ) + output_messages.append(function_call_item) + + # Emit output_item.done event for completed function call + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemDone( + response_id=self.response_id, + item=function_call_item, + output_index=len(output_messages) - 1, + sequence_number=self.sequence_number, + ) + + async def _process_tools( + self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput] + ) -> AsyncIterator[OpenAIResponseObjectStream]: + """Process all tools and emit appropriate streaming events.""" + from openai.types.chat import ChatCompletionToolParam + + from llama_stack.apis.tools import Tool + from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition + from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool + + def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam: + tool_def = ToolDefinition( + tool_name=tool_name, + description=tool.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in tool.parameters + }, + ) + return convert_tooldef_to_openai_tool(tool_def) + + # Initialize chat_tools if not already set + if self.ctx.chat_tools is None: + self.ctx.chat_tools = [] + + for input_tool in tools: + if input_tool.type == "function": + self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) + elif input_tool.type in WebSearchToolTypes: + tool_name = "web_search" + # Need to access tool_groups_api from tool_executor + tool = await self.tool_executor.tool_groups_api.get_tool(tool_name) + if not tool: + raise ValueError(f"Tool {tool_name} not found") + self.ctx.chat_tools.append(make_openai_tool(tool_name, tool)) + elif input_tool.type == "file_search": + tool_name = "knowledge_search" + tool = await self.tool_executor.tool_groups_api.get_tool(tool_name) + if not tool: + raise ValueError(f"Tool {tool_name} not found") + self.ctx.chat_tools.append(make_openai_tool(tool_name, tool)) + elif input_tool.type == "mcp": + async for stream_event in self._process_mcp_tool(input_tool, output_messages): + yield stream_event + else: + raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: 
{input_tool.type}") + + async def _process_mcp_tool( + self, mcp_tool: OpenAIResponseInputToolMCP, output_messages: list[OpenAIResponseOutput] + ) -> AsyncIterator[OpenAIResponseObjectStream]: + """Process an MCP tool configuration and emit appropriate streaming events.""" + from llama_stack.providers.utils.tools.mcp import list_mcp_tools + + # Emit mcp_list_tools.in_progress + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress( + sequence_number=self.sequence_number, + ) + + try: + # Parse allowed/never allowed tools + always_allowed = None + never_allowed = None + if mcp_tool.allowed_tools: + if isinstance(mcp_tool.allowed_tools, list): + always_allowed = mcp_tool.allowed_tools + elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter): + always_allowed = mcp_tool.allowed_tools.always + never_allowed = mcp_tool.allowed_tools.never + + # Call list_mcp_tools + tool_defs = await list_mcp_tools( + endpoint=mcp_tool.server_url, + headers=mcp_tool.headers or {}, + ) + + # Create the MCP list tools message + mcp_list_message = OpenAIResponseOutputMessageMCPListTools( + id=f"mcp_list_{uuid.uuid4()}", + server_label=mcp_tool.server_label, + tools=[], + ) + + # Process tools and update context + for t in tool_defs.data: + if never_allowed and t.name in never_allowed: + continue + if not always_allowed or t.name in always_allowed: + # Add to chat tools for inference + from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition + from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool + + tool_def = ToolDefinition( + tool_name=t.name, + description=t.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in t.parameters + }, + ) + openai_tool = convert_tooldef_to_openai_tool(tool_def) + if self.ctx.chat_tools is None: + self.ctx.chat_tools = [] + self.ctx.chat_tools.append(openai_tool) + + # Add to MCP tool mapping + if t.name in self.mcp_tool_to_server: + raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}") + self.mcp_tool_to_server[t.name] = mcp_tool + + # Add to MCP list message + mcp_list_message.tools.append( + MCPListToolsTool( + name=t.name, + description=t.description, + input_schema={ + "type": "object", + "properties": { + p.name: { + "type": p.parameter_type, + "description": p.description, + } + for p in t.parameters + }, + "required": [p.name for p in t.parameters if p.required], + }, + ) + ) + + # Add the MCP list message to output + output_messages.append(mcp_list_message) + + # Emit output_item.added for the MCP list tools message + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemAdded( + response_id=self.response_id, + item=mcp_list_message, + output_index=len(output_messages) - 1, + sequence_number=self.sequence_number, + ) + + # Emit mcp_list_tools.completed + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted( + sequence_number=self.sequence_number, + ) + + # Emit output_item.done for the MCP list tools message + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemDone( + response_id=self.response_id, + item=mcp_list_message, + output_index=len(output_messages) - 1, + sequence_number=self.sequence_number, + ) + + except Exception as e: + # TODO: Emit mcp_list_tools.failed event if needed + 
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}") + raise diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py new file mode 100644 index 000000000..5b98b4f51 --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py @@ -0,0 +1,379 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import json +from collections.abc import AsyncIterator + +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseInputToolFileSearch, + OpenAIResponseInputToolMCP, + OpenAIResponseObjectStreamResponseMcpCallCompleted, + OpenAIResponseObjectStreamResponseMcpCallFailed, + OpenAIResponseObjectStreamResponseMcpCallInProgress, + OpenAIResponseObjectStreamResponseWebSearchCallCompleted, + OpenAIResponseObjectStreamResponseWebSearchCallInProgress, + OpenAIResponseObjectStreamResponseWebSearchCallSearching, + OpenAIResponseOutputMessageFileSearchToolCall, + OpenAIResponseOutputMessageFileSearchToolCallResults, + OpenAIResponseOutputMessageWebSearchToolCall, +) +from llama_stack.apis.common.content_types import ( + ImageContentItem, + TextContentItem, +) +from llama_stack.apis.inference import ( + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, + OpenAIChatCompletionToolCall, + OpenAIImageURL, + OpenAIToolMessageParam, +) +from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime +from llama_stack.apis.vector_io import VectorIO +from llama_stack.log import get_logger + +from .types import ChatCompletionContext, ToolExecutionResult + +logger = get_logger(name=__name__, category="responses") + + +class ToolExecutor: + def __init__( + self, + tool_groups_api: ToolGroups, + tool_runtime_api: ToolRuntime, + vector_io_api: VectorIO, + ): + self.tool_groups_api = tool_groups_api + self.tool_runtime_api = tool_runtime_api + self.vector_io_api = vector_io_api + + async def execute_tool_call( + self, + tool_call: OpenAIChatCompletionToolCall, + ctx: ChatCompletionContext, + sequence_number: int, + output_index: int, + item_id: str, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> AsyncIterator[ToolExecutionResult]: + tool_call_id = tool_call.id + function = tool_call.function + tool_kwargs = json.loads(function.arguments) if function.arguments else {} + + if not function or not tool_call_id or not function.name: + yield ToolExecutionResult(sequence_number=sequence_number) + return + + # Emit progress events for tool execution start + async for event_result in self._emit_progress_events( + function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server + ): + sequence_number = event_result.sequence_number + yield event_result + + # Execute the actual tool call + error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server) + + # Emit completion events for tool execution + has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message)) + async for event_result in self._emit_completion_events( + function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server + ): + sequence_number = event_result.sequence_number + yield event_result + + # Build 
result messages from tool execution + output_message, input_message = await self._build_result_messages( + function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server + ) + + # Yield the final result + yield ToolExecutionResult( + sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message + ) + + async def _execute_knowledge_search_via_vector_store( + self, + query: str, + response_file_search_tool: OpenAIResponseInputToolFileSearch, + ) -> ToolInvocationResult: + """Execute knowledge search using vector_stores.search API with filters support.""" + search_results = [] + + # Create search tasks for all vector stores + async def search_single_store(vector_store_id): + try: + search_response = await self.vector_io_api.openai_search_vector_store( + vector_store_id=vector_store_id, + query=query, + filters=response_file_search_tool.filters, + max_num_results=response_file_search_tool.max_num_results, + ranking_options=response_file_search_tool.ranking_options, + rewrite_query=False, + ) + return search_response.data + except Exception as e: + logger.warning(f"Failed to search vector store {vector_store_id}: {e}") + return [] + + # Run all searches in parallel using gather + search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids] + all_results = await asyncio.gather(*search_tasks) + + # Flatten results + for results in all_results: + search_results.extend(results) + + # Convert search results to tool result format matching memory.py + # Format the results as interleaved content similar to memory.py + content_items = [] + content_items.append( + TextContentItem( + text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n" + ) + ) + + for i, result_item in enumerate(search_results): + chunk_text = result_item.content[0].text if result_item.content else "" + metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}" + if result_item.attributes: + metadata_text += f", attributes: {result_item.attributes}" + text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n" + content_items.append(TextContentItem(text=text_content)) + + content_items.append(TextContentItem(text="END of knowledge_search tool results.\n")) + content_items.append( + TextContentItem( + text=f'The above results were retrieved to help answer the user\'s query: "{query}". 
Use them as supporting information only in answering this query.\n', + ) + ) + + return ToolInvocationResult( + content=content_items, + metadata={ + "document_ids": [r.file_id for r in search_results], + "chunks": [r.content[0].text if r.content else "" for r in search_results], + "scores": [r.score for r in search_results], + }, + ) + + async def _emit_progress_events( + self, + function_name: str, + ctx: ChatCompletionContext, + sequence_number: int, + output_index: int, + item_id: str, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> AsyncIterator[ToolExecutionResult]: + """Emit progress events for tool execution start.""" + # Emit in_progress event based on tool type (only for tools with specific streaming events) + progress_event = None + if mcp_tool_to_server and function_name in mcp_tool_to_server: + sequence_number += 1 + progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + elif function_name == "web_search": + sequence_number += 1 + progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + # Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec + + if progress_event: + yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number) + + # For web search, emit searching event + if function_name == "web_search": + sequence_number += 1 + searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number) + + async def _execute_tool( + self, + function_name: str, + tool_kwargs: dict, + ctx: ChatCompletionContext, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> tuple[Exception | None, any]: + """Execute the tool and return error exception and result.""" + error_exc = None + result = None + + try: + if mcp_tool_to_server and function_name in mcp_tool_to_server: + from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool + + mcp_tool = mcp_tool_to_server[function_name] + result = await invoke_mcp_tool( + endpoint=mcp_tool.server_url, + headers=mcp_tool.headers or {}, + tool_name=function_name, + kwargs=tool_kwargs, + ) + elif function_name == "knowledge_search": + response_file_search_tool = next( + (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), + None, + ) + if response_file_search_tool: + # Use vector_stores.search API instead of knowledge_search tool + # to support filters and ranking_options + query = tool_kwargs.get("query", "") + result = await self._execute_knowledge_search_via_vector_store( + query=query, + response_file_search_tool=response_file_search_tool, + ) + else: + result = await self.tool_runtime_api.invoke_tool( + tool_name=function_name, + kwargs=tool_kwargs, + ) + except Exception as e: + error_exc = e + + return error_exc, result + + async def _emit_completion_events( + self, + function_name: str, + ctx: ChatCompletionContext, + sequence_number: int, + output_index: int, + item_id: str, + has_error: bool, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> AsyncIterator[ToolExecutionResult]: + """Emit completion or failure events for tool execution.""" + 
completion_event = None + + if mcp_tool_to_server and function_name in mcp_tool_to_server: + sequence_number += 1 + if has_error: + completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed( + sequence_number=sequence_number, + ) + else: + completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( + sequence_number=sequence_number, + ) + elif function_name == "web_search": + sequence_number += 1 + completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + # Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec + + if completion_event: + yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number) + + async def _build_result_messages( + self, + function, + tool_call_id: str, + tool_kwargs: dict, + ctx: ChatCompletionContext, + error_exc: Exception | None, + result: any, + has_error: bool, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> tuple[any, any]: + """Build output and input messages from tool execution results.""" + from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, + ) + + # Build output message + if mcp_tool_to_server and function.name in mcp_tool_to_server: + from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseOutputMessageMCPCall, + ) + + message = OpenAIResponseOutputMessageMCPCall( + id=tool_call_id, + arguments=function.arguments, + name=function.name, + server_label=mcp_tool_to_server[function.name].server_label, + ) + if error_exc: + message.error = str(error_exc) + elif (result and result.error_code and result.error_code > 0) or (result and result.error_message): + message.error = f"Error (code {result.error_code}): {result.error_message}" + elif result and result.content: + message.output = interleaved_content_as_str(result.content) + else: + if function.name == "web_search": + message = OpenAIResponseOutputMessageWebSearchToolCall( + id=tool_call_id, + status="completed", + ) + if has_error: + message.status = "failed" + elif function.name == "knowledge_search": + message = OpenAIResponseOutputMessageFileSearchToolCall( + id=tool_call_id, + queries=[tool_kwargs.get("query", "")], + status="completed", + ) + if result and "document_ids" in result.metadata: + message.results = [] + for i, doc_id in enumerate(result.metadata["document_ids"]): + text = result.metadata["chunks"][i] if "chunks" in result.metadata else None + score = result.metadata["scores"][i] if "scores" in result.metadata else None + message.results.append( + OpenAIResponseOutputMessageFileSearchToolCallResults( + file_id=doc_id, + filename=doc_id, + text=text, + score=score, + attributes={}, + ) + ) + if has_error: + message.status = "failed" + else: + raise ValueError(f"Unknown tool {function.name} called") + + # Build input message + input_message = None + if result and result.content: + if isinstance(result.content, str): + content = result.content + elif isinstance(result.content, list): + content = [] + for item in result.content: + if isinstance(item, TextContentItem): + part = OpenAIChatCompletionContentPartTextParam(text=item.text) + elif isinstance(item, ImageContentItem): + if item.image.data: + url = f"data:image;base64,{item.image.data}" + else: + url = item.image.url + part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) + else: + raise ValueError(f"Unknown result 
content type: {type(item)}") + content.append(part) + else: + raise ValueError(f"Unknown result content type: {type(result.content)}") + input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) + else: + text = str(error_exc) if error_exc else "Tool execution failed" + input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id) + + return message, input_message diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/types.py b/llama_stack/providers/inline/agents/meta_reference/responses/types.py new file mode 100644 index 000000000..89086c262 --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/types.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from dataclasses import dataclass + +from openai.types.chat import ChatCompletionToolParam +from pydantic import BaseModel + +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseInputTool, + OpenAIResponseObjectStream, + OpenAIResponseOutput, +) +from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam + + +class ToolExecutionResult(BaseModel): + """Result of streaming tool execution.""" + + stream_event: OpenAIResponseObjectStream | None = None + sequence_number: int + final_output_message: OpenAIResponseOutput | None = None + final_input_message: OpenAIMessageParam | None = None + + +@dataclass +class ChatCompletionResult: + """Result of processing streaming chat completion chunks.""" + + response_id: str + content: list[str] + tool_calls: dict[int, OpenAIChatCompletionToolCall] + created: int + model: str + finish_reason: str + message_item_id: str # For streaming events + tool_call_item_ids: dict[int, str] # For streaming events + content_part_emitted: bool # Tracking state + + @property + def content_text(self) -> str: + """Get joined content as string.""" + return "".join(self.content) + + @property + def has_tool_calls(self) -> bool: + """Check if there are any tool calls.""" + return bool(self.tool_calls) + + +class ChatCompletionContext(BaseModel): + model: str + messages: list[OpenAIMessageParam] + response_tools: list[OpenAIResponseInputTool] | None = None + chat_tools: list[ChatCompletionToolParam] | None = None + temperature: float | None + response_format: OpenAIResponseFormatParam diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py new file mode 100644 index 000000000..1507a55c8 --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import uuid + +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseInput, + OpenAIResponseInputFunctionToolCallOutput, + OpenAIResponseInputMessageContent, + OpenAIResponseInputMessageContentImage, + OpenAIResponseInputMessageContentText, + OpenAIResponseInputTool, + OpenAIResponseMessage, + OpenAIResponseOutputMessageContent, + OpenAIResponseOutputMessageContentOutputText, + OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseText, +) +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartParam, + OpenAIChatCompletionContentPartTextParam, + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + OpenAIChoice, + OpenAIDeveloperMessageParam, + OpenAIImageURL, + OpenAIJSONSchema, + OpenAIMessageParam, + OpenAIResponseFormatJSONObject, + OpenAIResponseFormatJSONSchema, + OpenAIResponseFormatParam, + OpenAIResponseFormatText, + OpenAISystemMessageParam, + OpenAIToolMessageParam, + OpenAIUserMessageParam, +) + + +async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage: + """Convert an OpenAI Chat Completion choice into an OpenAI Response output message.""" + output_content = "" + if isinstance(choice.message.content, str): + output_content = choice.message.content + elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam): + output_content = choice.message.content.text + else: + raise ValueError( + f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}" + ) + + return OpenAIResponseMessage( + id=f"msg_{uuid.uuid4()}", + content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)], + status="completed", + role="assistant", + ) + + +async def convert_response_content_to_chat_content( + content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]), +) -> str | list[OpenAIChatCompletionContentPartParam]: + """ + Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts. + + The content schemas of each API look similar, but are not exactly the same. + """ + if isinstance(content, str): + return content + + converted_parts = [] + for content_part in content: + if isinstance(content_part, OpenAIResponseInputMessageContentText): + converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) + elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText): + converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) + elif isinstance(content_part, OpenAIResponseInputMessageContentImage): + if content_part.image_url: + image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail) + converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url)) + elif isinstance(content_part, str): + converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part)) + else: + raise ValueError( + f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context" + ) + return converted_parts + + +async def convert_response_input_to_chat_messages( + input: str | list[OpenAIResponseInput], +) -> list[OpenAIMessageParam]: + """ + Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages. 
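+    Function tool calls and their outputs become assistant tool-call and tool messages; other items are converted by role.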
+ """ + messages: list[OpenAIMessageParam] = [] + if isinstance(input, list): + for input_item in input: + if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): + messages.append( + OpenAIToolMessageParam( + content=input_item.output, + tool_call_id=input_item.call_id, + ) + ) + elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall): + tool_call = OpenAIChatCompletionToolCall( + index=0, + id=input_item.call_id, + function=OpenAIChatCompletionToolCallFunction( + name=input_item.name, + arguments=input_item.arguments, + ), + ) + messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) + else: + content = await convert_response_content_to_chat_content(input_item.content) + message_type = await get_message_type_by_role(input_item.role) + if message_type is None: + raise ValueError( + f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" + ) + messages.append(message_type(content=content)) + else: + messages.append(OpenAIUserMessageParam(content=input)) + return messages + + +async def convert_response_text_to_chat_response_format( + text: OpenAIResponseText, +) -> OpenAIResponseFormatParam: + """ + Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format. + """ + if not text.format or text.format["type"] == "text": + return OpenAIResponseFormatText(type="text") + if text.format["type"] == "json_object": + return OpenAIResponseFormatJSONObject() + if text.format["type"] == "json_schema": + return OpenAIResponseFormatJSONSchema( + json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"]) + ) + raise ValueError(f"Unsupported text format: {text.format}") + + +async def get_message_type_by_role(role: str): + role_to_type = { + "user": OpenAIUserMessageParam, + "system": OpenAISystemMessageParam, + "assistant": OpenAIAssistantMessageParam, + "developer": OpenAIDeveloperMessageParam, + } + return role_to_type.get(role) + + +def is_function_tool_call( + tool_call: OpenAIChatCompletionToolCall, + tools: list[OpenAIResponseInputTool], +) -> bool: + if not tool_call.function: + return False + for t in tools: + if t.type == "function" and t.name == tool_call.function.name: + return True + return False diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index f83c39a6a..bae744010 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -22,7 +22,7 @@ from llama_stack.apis.safety import ( SafetyViolation, ViolationLevel, ) -from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories +from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults from llama_stack.apis.shields import Shield from llama_stack.core.datatypes import Api from llama_stack.models.llama.datatypes import Role @@ -72,30 +72,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = { } SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()} -OPENAI_TO_LLAMA_CATEGORIES_MAP = { - OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES], - OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES], - OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION], - OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION], - OpenAICategories.HATE: [CAT_HATE], - OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES], - 
OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES], - OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS], - OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT], - OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION], - OpenAICategories.SELF_HARM: [CAT_SELF_HARM], - OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM], - OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE], - # These are custom categories that are not in the OpenAI moderation categories - "custom/defamation": [CAT_DEFAMATION], - "custom/specialized_advice": [CAT_SPECIALIZED_ADVICE], - "custom/privacy_violation": [CAT_PRIVACY], - "custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY], - "custom/weapons": [CAT_INDISCRIMINATE_WEAPONS], - "custom/elections": [CAT_ELECTIONS], - "custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE], -} - DEFAULT_LG_V3_SAFETY_CATEGORIES = [ CAT_VIOLENT_CRIMES, @@ -424,9 +400,9 @@ class LlamaGuardShield: ModerationObject with appropriate configuration """ # Set default values for safe case - categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False) - category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0) - category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()} + categories = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), False) + category_scores = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), 1.0) + category_applied_input_types = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()} flagged = False user_message = None metadata = {} @@ -453,19 +429,15 @@ class LlamaGuardShield: ], ) - # Get OpenAI categories for the unsafe codes - openai_categories = [] - for code in unsafe_code_list: - llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code] - openai_categories.extend( - k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l - ) + llama_guard_category = [SAFETY_CODE_TO_CATEGORIES_MAP[code] for code in unsafe_code_list] # Update categories for unsafe content - categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP} - category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP} + categories = {k: k in llama_guard_category for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()} + category_scores = { + k: 1.0 if k in llama_guard_category else 0.0 for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys() + } category_applied_input_types = { - k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP + k: ["text"] if k in llama_guard_category else [] for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys() } flagged = True user_message = CANNED_RESPONSE_TEXT diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index 796771ee1..c760f0fd1 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -15,8 +15,10 @@ from llama_stack.apis.safety import ( RunShieldResponse, Safety, SafetyViolation, + ShieldStore, ViolationLevel, ) +from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield from llama_stack.core.utils.model_utils import model_local_dir from llama_stack.providers.datatypes import ShieldsProtocolPrivate @@ -32,6 +34,8 @@ PROMPT_GUARD_MODEL = "Prompt-Guard-86M" class PromptGuardSafetyImpl(Safety, 
ShieldsProtocolPrivate): + shield_store: ShieldStore + def __init__(self, config: PromptGuardConfig, _deps) -> None: self.config = config @@ -53,7 +57,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): self, shield_id: str, messages: list[Message], - params: dict[str, Any] = None, + params: dict[str, Any], ) -> RunShieldResponse: shield = await self.shield_store.get_shield(shield_id) if not shield: @@ -61,6 +65,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): return await self.shield.run(messages) + async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: + raise NotImplementedError("run_moderation is not implemented for Prompt Guard") + class PromptGuardShield: def __init__( @@ -117,8 +124,10 @@ class PromptGuardShield: elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold: violation = SafetyViolation( violation_level=ViolationLevel.ERROR, - violation_type=f"prompt_injection:malicious={score_malicious}", - violation_return_message="Sorry, I cannot do this.", + user_message="Sorry, I cannot do this.", + metadata={ + "violation_type": f"prompt_injection:malicious={score_malicious}", + }, ) return RunShieldResponse(violation=violation) diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 7a5373726..af61da59b 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -33,6 +33,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -128,11 +129,12 @@ class FaissIndex(EmbeddingIndex): # Save updated index await self._save_index() - async def delete_chunk(self, chunk_id: str) -> None: - if chunk_id not in self.chunk_ids: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: + chunk_ids = [c.chunk_id for c in chunks_for_deletion] + if not set(chunk_ids).issubset(self.chunk_ids): return - async with self.chunk_id_lock: + def remove_chunk(chunk_id: str): index = self.chunk_ids.index(chunk_id) self.index.remove_ids(np.array([index])) @@ -146,6 +148,10 @@ class FaissIndex(EmbeddingIndex): self.chunk_by_index = new_chunk_by_index self.chunk_ids.pop(index) + async with self.chunk_id_lock: + for chunk_id in chunk_ids: + remove_chunk(chunk_id) + await self._save_index() async def query_vector( @@ -174,7 +180,9 @@ class FaissIndex(EmbeddingIndex): k: int, score_threshold: float, ) -> QueryChunksResponse: - raise NotImplementedError("Keyword search is not supported in FAISS") + raise NotImplementedError( + "Keyword search is not supported - underlying DB FAISS does not support this search mode" + ) async def query_hybrid( self, @@ -185,7 +193,9 @@ class FaissIndex(EmbeddingIndex): reranker_type: str, reranker_params: dict[str, Any] | None = None, ) -> QueryChunksResponse: - raise NotImplementedError("Hybrid search is not supported in FAISS") + raise NotImplementedError( + "Hybrid search is not supported - underlying DB FAISS does not support this search mode" + ) class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): @@ -293,8 +303,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, 
VectorDBsProtocolPr return await index.query_chunks(query, params) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: - """Delete a chunk from a faiss index""" + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete chunks from a faiss index""" faiss_index = self.cache[store_id].index - for chunk_id in chunk_ids: - await faiss_index.delete_chunk(chunk_id) + await faiss_index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 1fff7b484..cc1982f3b 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -31,6 +31,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIV from llama_stack.providers.utils.memory.vector_store import ( RERANKER_TYPE_RRF, RERANKER_TYPE_WEIGHTED, + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -426,34 +427,36 @@ class SQLiteVecIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) - async def delete_chunk(self, chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Remove a chunk from the SQLite vector store.""" + chunk_ids = [c.chunk_id for c in chunks_for_deletion] - def _delete_chunk(): + def _delete_chunks(): connection = _create_sqlite_connection(self.db_path) cur = connection.cursor() try: cur.execute("BEGIN TRANSACTION") # Delete from metadata table - cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,)) + placeholders = ",".join("?" * len(chunk_ids)) + cur.execute(f"DELETE FROM {self.metadata_table} WHERE id IN ({placeholders})", chunk_ids) # Delete from vector table - cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,)) + cur.execute(f"DELETE FROM {self.vector_table} WHERE id IN ({placeholders})", chunk_ids) # Delete from FTS table - cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,)) + cur.execute(f"DELETE FROM {self.fts_table} WHERE id IN ({placeholders})", chunk_ids) connection.commit() except Exception as e: connection.rollback() - logger.error(f"Error deleting chunk {chunk_id}: {e}") + logger.error(f"Error deleting chunks: {e}") raise finally: cur.close() connection.close() - await asyncio.to_thread(_delete_chunk) + await asyncio.to_thread(_delete_chunks) class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): @@ -551,12 +554,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc raise VectorStoreNotFoundError(vector_db_id) return await index.query_chunks(query, params) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: - """Delete a chunk from a sqlite_vec index.""" + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete chunks from a sqlite_vec index.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: raise VectorStoreNotFoundError(store_id) - for chunk_id in chunk_ids: - # Use the index's delete_chunk method - await index.index.delete_chunk(chunk_id) + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index a8bc96a77..1801cdcad 100644 --- a/llama_stack/providers/registry/inference.py 
+++ b/llama_stack/providers/registry/inference.py @@ -213,6 +213,36 @@ def available_providers() -> list[ProviderSpec]: description="Google Gemini inference provider for accessing Gemini models and Google's AI services.", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="vertexai", + pip_packages=["litellm", "google-cloud-aiplatform"], + module="llama_stack.providers.remote.inference.vertexai", + config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig", + provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator", + description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages: + +β€’ Enterprise-grade security: Uses Google Cloud's security controls and IAM +β€’ Better integration: Seamless integration with other Google Cloud services +β€’ Advanced features: Access to additional Vertex AI features like model tuning and monitoring +β€’ Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys + +Configuration: +- Set VERTEX_AI_PROJECT environment variable (required) +- Set VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1) +- Use Google Cloud Application Default Credentials or service account key + +Authentication Setup: +Option 1 (Recommended): gcloud auth application-default login +Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to service account key path + +Available Models: +- vertex_ai/gemini-2.0-flash +- vertex_ai/gemini-2.5-flash +- vertex_ai/gemini-2.5-pro""", + ), + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 846f7b88e..70148eb15 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -45,6 +45,18 @@ That means you'll get fast and efficient vector retrieval. - Lightweight and easy to use - Fully integrated with Llama Stack - GPU support +- **Vector search** - FAISS supports pure vector similarity search using embeddings + +## Search Modes + +**Supported:** +- **Vector Search** (`mode="vector"`): Performs vector similarity search using embeddings + +**Not Supported:** +- **Keyword Search** (`mode="keyword"`): Not supported by FAISS +- **Hybrid Search** (`mode="hybrid"`): Not supported by FAISS + +> **Note**: FAISS is designed as a pure vector similarity search library. See the [FAISS GitHub repository](https://github.com/facebookresearch/faiss) for more details about FAISS's core functionality. ## Usage @@ -330,6 +342,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti """, ), api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, @@ -338,6 +351,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti module="llama_stack.providers.inline.vector_io.chroma", config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig", api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], description=""" [Chroma](https://www.trychroma.com/) is an inline and remote vector database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database. 
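As a point of reference for the Vertex AI inference provider registered above, here is a minimal sketch of how its configuration could be constructed directly in Python. The project ID is a placeholder, and authentication is assumed to come from Application Default Credentials rather than an API key, as the provider description notes.

```python
# Sketch only: "my-gcp-project" is a placeholder; credentials are expected to come
# from Application Default Credentials (e.g. `gcloud auth application-default login`).
from llama_stack.providers.remote.inference.vertexai.config import VertexAIConfig

config = VertexAIConfig(project="my-gcp-project", location="us-central1")

# The sample run config resolves the same fields from environment variables:
print(VertexAIConfig.sample_run_config())
# {'project': '${env.VERTEX_AI_PROJECT:=}', 'location': '${env.VERTEX_AI_LOCATION:=us-central1}'}
```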
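Similarly, for the FAISS search-mode notes above: a FAISS-backed vector store is queried through the same `client.vector_stores.search` call used in the Milvus examples below, but only vector mode is valid. The `client` and `vector_store` handles here are assumptions, matching those examples.

```python
# FAISS supports pure vector similarity search; "keyword" and "hybrid" modes
# raise NotImplementedError in the FAISS index.
search_response = client.vector_stores.search(
    vector_store_id=vector_store.id,
    query="What is machine learning?",
    search_mode="vector",  # the only mode FAISS supports
    max_num_results=5,
)
```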
@@ -452,6 +466,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more """, ), api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, @@ -535,6 +550,7 @@ That means you're not limited to storing vectors in memory or in a separate serv - Easy to use - Fully integrated with Llama Stack +- Supports all search modes: vector, keyword, and hybrid search (both inline and remote configurations) ## Usage @@ -625,6 +641,92 @@ vector_io: - **`client_pem_path`**: Path to the **client certificate** file (required for mTLS). - **`client_key_path`**: Path to the **client private key** file (required for mTLS). +## Search Modes + +Milvus supports three different search modes for both inline and remote configurations: + +### Vector Search +Vector search uses semantic similarity to find the most relevant chunks based on embedding vectors. This is the default search mode and works well for finding conceptually similar content. + +```python +# Vector search example +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="What is machine learning?", + search_mode="vector", + max_num_results=5, +) +``` + +### Keyword Search +Keyword search uses traditional text-based matching to find chunks containing specific terms or phrases. This is useful when you need exact term matches. + +```python +# Keyword search example +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="Python programming language", + search_mode="keyword", + max_num_results=5, +) +``` + +### Hybrid Search +Hybrid search combines both vector and keyword search methods to provide more comprehensive results. It leverages the strengths of both semantic similarity and exact term matching. + +#### Basic Hybrid Search +```python +# Basic hybrid search example (uses RRF ranker with default impact_factor=60.0) +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, +) +``` + +**Note**: The default `impact_factor` value of 60.0 was empirically determined to be optimal in the original RRF research paper: ["Reciprocal Rank Fusion outperforms Condorcet and individual Rank Learning Methods"](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) (Cormack et al., 2009). + +#### Hybrid Search with RRF (Reciprocal Rank Fusion) Ranker +RRF combines rankings from vector and keyword search by using reciprocal ranks. The impact factor controls how much weight is given to higher-ranked results. + +```python +# Hybrid search with custom RRF parameters +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, + ranking_options={ + "ranker": { + "type": "rrf", + "impact_factor": 100.0, # Higher values give more weight to top-ranked results + } + }, +) +``` + +#### Hybrid Search with Weighted Ranker +Weighted ranker linearly combines normalized scores from vector and keyword search. The alpha parameter controls the balance between the two search methods. 
+ +```python +# Hybrid search with weighted ranker +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, + ranking_options={ + "ranker": { + "type": "weighted", + "alpha": 0.7, # 70% vector search, 30% keyword search + } + }, +) +``` + +For detailed documentation on RRF and Weighted rankers, please refer to the [Milvus Reranking Guide](https://milvus.io/docs/reranking.md). + ## Documentation See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general. @@ -632,6 +734,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi """, ), api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index ca4c7b578..bd86f7238 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -235,6 +235,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv llama_model = self.get_llama_model(request.model) if isinstance(request, ChatCompletionRequest): + # TODO: tools are never added to the request, so we need to add them here if media_present or not llama_model: input_dict["messages"] = [ await convert_message_to_openai_dict(m, download=True) for m in request.messages @@ -378,6 +379,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv # Fireworks chat completions OpenAI-compatible API does not support # tool calls properly. llama_model = self.get_llama_model(model_obj.provider_resource_id) + if llama_model: return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion( self, @@ -431,4 +433,5 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv user=user, ) + logger.debug(f"fireworks params: {params}") return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 26b4dec76..a93421536 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -457,9 +457,6 @@ class OllamaInferenceAdapter( user: str | None = None, ) -> OpenAIEmbeddingsResponse: model_obj = await self._get_model(model) - if model_obj.model_type != ModelType.embedding: - raise ValueError(f"Model {model} is not an embedding model") - if model_obj.provider_resource_id is None: raise ValueError(f"Model {model} has no provider_resource_id set") diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index a5bb079ef..323831845 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -308,9 +308,7 @@ class TGIAdapter(_HfAdapter): if not config.url: raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.") log.info(f"Initializing TGI client with url={config.url}") - self.client = AsyncInferenceClient( - model=config.url, - ) + self.client = AsyncInferenceClient(model=config.url, provider="hf-inference") endpoint_info = await 
self.client.get_endpoint_info() self.max_tokens = endpoint_info["max_total_tokens"] self.model_id = endpoint_info["model_id"] diff --git a/llama_stack/providers/remote/inference/vertexai/__init__.py b/llama_stack/providers/remote/inference/vertexai/__init__.py new file mode 100644 index 000000000..d9e9419be --- /dev/null +++ b/llama_stack/providers/remote/inference/vertexai/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import VertexAIConfig + + +async def get_adapter_impl(config: VertexAIConfig, _deps): + from .vertexai import VertexAIInferenceAdapter + + impl = VertexAIInferenceAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/inference/vertexai/config.py b/llama_stack/providers/remote/inference/vertexai/config.py new file mode 100644 index 000000000..659de653e --- /dev/null +++ b/llama_stack/providers/remote/inference/vertexai/config.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from pydantic import BaseModel, Field + +from llama_stack.schema_utils import json_schema_type + + +class VertexAIProviderDataValidator(BaseModel): + vertex_project: str | None = Field( + default=None, + description="Google Cloud project ID for Vertex AI", + ) + vertex_location: str | None = Field( + default=None, + description="Google Cloud location for Vertex AI (e.g., us-central1)", + ) + + +@json_schema_type +class VertexAIConfig(BaseModel): + project: str = Field( + description="Google Cloud project ID for Vertex AI", + ) + location: str = Field( + default="us-central1", + description="Google Cloud location for Vertex AI", + ) + + @classmethod + def sample_run_config( + cls, + project: str = "${env.VERTEX_AI_PROJECT:=}", + location: str = "${env.VERTEX_AI_LOCATION:=us-central1}", + **kwargs, + ) -> dict[str, Any]: + return { + "project": project, + "location": location, + } diff --git a/llama_stack/providers/remote/inference/vertexai/models.py b/llama_stack/providers/remote/inference/vertexai/models.py new file mode 100644 index 000000000..e72db533d --- /dev/null +++ b/llama_stack/providers/remote/inference/vertexai/models.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.utils.inference.model_registry import ( + ProviderModelEntry, +) + +# Vertex AI model IDs with vertex_ai/ prefix as required by litellm +LLM_MODEL_IDS = [ + "vertex_ai/gemini-2.0-flash", + "vertex_ai/gemini-2.5-flash", + "vertex_ai/gemini-2.5-pro", +] + +SAFETY_MODELS_ENTRIES = list[ProviderModelEntry]() + +MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + SAFETY_MODELS_ENTRIES diff --git a/llama_stack/providers/remote/inference/vertexai/vertexai.py b/llama_stack/providers/remote/inference/vertexai/vertexai.py new file mode 100644 index 000000000..8807fd0e6 --- /dev/null +++ b/llama_stack/providers/remote/inference/vertexai/vertexai.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.apis.inference import ChatCompletionRequest +from llama_stack.providers.utils.inference.litellm_openai_mixin import ( + LiteLLMOpenAIMixin, +) + +from .config import VertexAIConfig +from .models import MODEL_ENTRIES + + +class VertexAIInferenceAdapter(LiteLLMOpenAIMixin): + def __init__(self, config: VertexAIConfig) -> None: + LiteLLMOpenAIMixin.__init__( + self, + MODEL_ENTRIES, + litellm_provider_name="vertex_ai", + api_key_from_config=None, # Vertex AI uses ADC, not API keys + provider_data_api_key_field="vertex_project", # Use project for validation + ) + self.config = config + + def get_api_key(self) -> str: + # Vertex AI doesn't use API keys, it uses Application Default Credentials + # Return empty string to let litellm handle authentication via ADC + return "" + + async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]: + # Get base parameters from parent + params = await super()._get_params(request) + + # Add Vertex AI specific parameters + provider_data = self.get_request_provider_data() + if provider_data: + if getattr(provider_data, "vertex_project", None): + params["vertex_project"] = provider_data.vertex_project + if getattr(provider_data, "vertex_location", None): + params["vertex_location"] = provider_data.vertex_location + else: + params["vertex_project"] = self.config.project + params["vertex_location"] = self.config.location + + # Remove api_key since Vertex AI uses ADC + params.pop("api_key", None) + + return params diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 26aeaedfb..8f252711b 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -26,6 +26,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -115,8 +116,10 @@ class ChromaIndex(EmbeddingIndex): ) -> QueryChunksResponse: raise NotImplementedError("Keyword search is not supported in Chroma") - async def delete_chunk(self, chunk_id: str) -> None: - raise NotImplementedError("delete_chunk is not supported in Chroma") + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete a single chunk from the Chroma collection by its ID.""" + ids = [f"{chunk.document_id}:{chunk.chunk_id}" for chunk in chunks_for_deletion] + await maybe_await(self.collection.delete(ids=ids)) async def query_hybrid( self, @@ -144,6 +147,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.cache = {} self.kvstore: KVStore | None = None self.vector_db_store = None + self.files_api = files_api async def initialize(self) -> None: self.kvstore = await kvstore_impl(self.config.kvstore) @@ -227,5 +231,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.cache[vector_db_id] = index return index - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma") + async def delete_chunks(self, store_id: str, 
chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete chunks from a Chroma vector store.""" + index = await self._get_and_cache_vector_db_index(store_id) + if not index: + raise ValueError(f"Vector DB {store_id} not found") + + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index b09edb65c..0eaae81b3 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -28,6 +28,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( RERANKER_TYPE_WEIGHTED, + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -287,14 +288,17 @@ class MilvusIndex(EmbeddingIndex): return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores) - async def delete_chunk(self, chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Remove a chunk from the Milvus collection.""" + chunk_ids = [c.chunk_id for c in chunks_for_deletion] try: + # Use IN clause with square brackets and single quotes for VARCHAR field + chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids) await asyncio.to_thread( - self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"' + self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]" ) except Exception as e: - logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}") + logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}") raise @@ -420,12 +424,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP return await index.query_chunks(query, params) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Delete a chunk from a milvus vector store.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: raise VectorStoreNotFoundError(store_id) - for chunk_id in chunk_ids: - # Use the index's delete_chunk method - await index.index.delete_chunk(chunk_id) + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index b1645ac5a..d2a5d910b 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -163,10 +164,11 @@ class PGVectorIndex(EmbeddingIndex): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") - async def delete_chunk(self, chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Remove a chunk from the PostgreSQL 
table.""" + chunk_ids = [c.chunk_id for c in chunks_for_deletion] with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,)) + cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,)) class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): @@ -275,12 +277,10 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api) return self.cache[vector_db_id] - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Delete a chunk from a PostgreSQL vector store.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: raise VectorStoreNotFoundError(store_id) - for chunk_id in chunk_ids: - # Use the index's delete_chunk method - await index.index.delete_chunk(chunk_id) + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 144da0f4f..018015780 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -29,6 +29,7 @@ from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig a from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -88,15 +89,16 @@ class QdrantIndex(EmbeddingIndex): await self.client.upsert(collection_name=self.collection_name, points=points) - async def delete_chunk(self, chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Remove a chunk from the Qdrant collection.""" + chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion] try: await self.client.delete( collection_name=self.collection_name, - points_selector=models.PointIdsList(points=[convert_id(chunk_id)]), + points_selector=models.PointIdsList(points=chunk_ids), ) except Exception as e: - log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}") + log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}") raise async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: @@ -264,12 +266,14 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP ) -> VectorStoreFileObject: # Qdrant doesn't allow multiple clients to access the same storage path simultaneously. 
async with self._qdrant_lock: - await super().openai_attach_file_to_vector_store(vector_store_id, file_id, attributes, chunking_strategy) + return await super().openai_attach_file_to_vector_store( + vector_store_id, file_id, attributes, chunking_strategy + ) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Delete chunks from a Qdrant vector store.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: raise ValueError(f"Vector DB {store_id} not found") - for chunk_id in chunk_ids: - await index.index.delete_chunk(chunk_id) + + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 11da8902c..966724848 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -26,6 +26,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import ( OpenAIVectorStoreMixin, ) from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -67,6 +68,7 @@ class WeaviateIndex(EmbeddingIndex): data_objects.append( wvc.data.DataObject( properties={ + "chunk_id": chunk.chunk_id, "chunk_content": chunk.model_dump_json(), }, vector=embeddings[i].tolist(), @@ -79,10 +81,11 @@ class WeaviateIndex(EmbeddingIndex): # TODO: make this async friendly collection.data.insert_many(data_objects) - async def delete_chunk(self, chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True) collection = self.client.collections.get(sanitized_collection_name) - collection.data.delete_many(where=Filter.by_property("id").contains_any([chunk_id])) + chunk_ids = [chunk.chunk_id for chunk in chunks_for_deletion] + collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids)) async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True) @@ -307,10 +310,10 @@ class WeaviateVectorIOAdapter( return await index.query_chunks(query, params) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True) index = await self._get_and_cache_vector_db_index(sanitized_collection_name) if not index: raise ValueError(f"Vector DB {sanitized_collection_name} not found") - await index.delete(chunk_ids) + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index e6e5ccc8a..6297cc2ed 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -31,15 +31,15 @@ from openai.types.chat import ( from openai.types.chat import ( ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, ) +from openai.types.chat import ( + ChatCompletionMessageFunctionToolCall as 
OpenAIChatCompletionMessageFunctionToolCall, +) from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessage, ) from openai.types.chat import ( ChatCompletionMessageToolCall, ) -from openai.types.chat import ( - ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall, -) from openai.types.chat import ( ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, ) @@ -70,7 +70,7 @@ from openai.types.chat.chat_completion_chunk import ( from openai.types.chat.chat_completion_content_part_image_param import ( ImageURL as OpenAIImageURL, ) -from openai.types.chat.chat_completion_message_tool_call_param import ( +from openai.types.chat.chat_completion_message_tool_call import ( Function as OpenAIFunction, ) from pydantic import BaseModel @@ -633,7 +633,7 @@ async def convert_message_to_openai_dict_new( ) elif isinstance(message, CompletionMessage): tool_calls = [ - OpenAIChatCompletionMessageToolCall( + OpenAIChatCompletionMessageFunctionToolCall( id=tool.call_id, function=OpenAIFunction( name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), @@ -903,7 +903,7 @@ def _convert_openai_request_response_format( def _convert_openai_tool_calls( - tool_calls: list[OpenAIChatCompletionMessageToolCall], + tool_calls: list[OpenAIChatCompletionMessageFunctionToolCall], ) -> list[ToolCall]: """ Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall. diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 7b6e69df1..120d0d4fc 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -6,7 +6,6 @@ import asyncio import json -import logging import mimetypes import time import uuid @@ -37,10 +36,15 @@ from llama_stack.apis.vector_io import ( VectorStoreSearchResponse, VectorStoreSearchResponsePage, ) +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore.api import KVStore -from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks +from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, + content_from_data_and_mime_type, + make_overlapped_chunks, +) -logger = logging.getLogger(__name__) +logger = get_logger(__name__, category="vector_io") # Constants for OpenAI vector stores CHUNK_MULTIPLIER = 5 @@ -154,8 +158,8 @@ class OpenAIVectorStoreMixin(ABC): self.openai_vector_stores = await self._load_openai_vector_stores() @abstractmethod - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: - """Delete a chunk from a vector store.""" + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete chunks from a vector store.""" pass @abstractmethod @@ -614,7 +618,7 @@ class OpenAIVectorStoreMixin(ABC): ) vector_store_file_object.status = "completed" except Exception as e: - logger.error(f"Error attaching file to vector store: {e}") + logger.exception("Error attaching file to vector store") vector_store_file_object.status = "failed" vector_store_file_object.last_error = VectorStoreFileLastError( code="server_error", @@ -767,7 +771,21 @@ class OpenAIVectorStoreMixin(ABC): dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id) chunks = [Chunk.model_validate(c) for c in dict_chunks] - await 
self.delete_chunks(vector_store_id, [str(c.chunk_id) for c in chunks if c.chunk_id]) + + # Create ChunkForDeletion objects with both chunk_id and document_id + chunks_for_deletion = [] + for c in chunks: + if c.chunk_id: + document_id = c.metadata.get("document_id") or ( + c.chunk_metadata.document_id if c.chunk_metadata else None + ) + if document_id: + chunks_for_deletion.append(ChunkForDeletion(chunk_id=str(c.chunk_id), document_id=document_id)) + else: + logger.warning(f"Chunk {c.chunk_id} has no document_id, skipping deletion") + + if chunks_for_deletion: + await self.delete_chunks(vector_store_id, chunks_for_deletion) store_info = self.openai_vector_stores[vector_store_id].copy() diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index bb9002f30..6ae5bb521 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -16,6 +16,7 @@ from urllib.parse import unquote import httpx import numpy as np from numpy.typing import NDArray +from pydantic import BaseModel from llama_stack.apis.common.content_types import ( URL, @@ -34,6 +35,18 @@ from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id log = logging.getLogger(__name__) + +class ChunkForDeletion(BaseModel): + """Information needed to delete a chunk from a vector store. + + :param chunk_id: The ID of the chunk to delete + :param document_id: The ID of the document this chunk belongs to + """ + + chunk_id: str + document_id: str + + # Constants for reranker types RERANKER_TYPE_RRF = "rrf" RERANKER_TYPE_WEIGHTED = "weighted" @@ -232,7 +245,7 @@ class EmbeddingIndex(ABC): raise NotImplementedError() @abstractmethod - async def delete_chunk(self, chunk_id: str): + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]): raise NotImplementedError() @abstractmethod diff --git a/llama_stack/ui/.nvmrc b/llama_stack/ui/.nvmrc new file mode 100644 index 000000000..1384ff6a1 --- /dev/null +++ b/llama_stack/ui/.nvmrc @@ -0,0 +1 @@ +22.5.1 diff --git a/llama_stack/ui/.prettierignore b/llama_stack/ui/.prettierignore index 1b8ac8894..b737ae6ed 100644 --- a/llama_stack/ui/.prettierignore +++ b/llama_stack/ui/.prettierignore @@ -1,3 +1,12 @@ # Ignore artifacts: build coverage +.next +node_modules +dist +*.lock +*.log + +# Generated files +*.min.js +*.min.css diff --git a/llama_stack/ui/.prettierrc b/llama_stack/ui/.prettierrc index 0967ef424..059475a24 100644 --- a/llama_stack/ui/.prettierrc +++ b/llama_stack/ui/.prettierrc @@ -1 +1,10 @@ -{} +{ + "semi": true, + "trailingComma": "es5", + "singleQuote": false, + "printWidth": 80, + "tabWidth": 2, + "useTabs": false, + "bracketSpacing": true, + "arrowParens": "avoid" +} diff --git a/llama_stack/ui/app/api/v1/[...path]/route.ts b/llama_stack/ui/app/api/v1/[...path]/route.ts index 1959f9099..51c1f8004 100644 --- a/llama_stack/ui/app/api/v1/[...path]/route.ts +++ b/llama_stack/ui/app/api/v1/[...path]/route.ts @@ -47,7 +47,7 @@ async function proxyRequest(request: NextRequest, method: string) { const responseText = await response.text(); console.log( - `Response from FastAPI: ${response.status} ${response.statusText}`, + `Response from FastAPI: ${response.status} ${response.statusText}` ); // Create response with same status and headers @@ -74,7 +74,7 @@ async function proxyRequest(request: NextRequest, method: string) { backend_url: BACKEND_URL, timestamp: new Date().toISOString(), }, - { status: 500 }, + { status: 500 } ); } } diff 
--git a/llama_stack/ui/app/auth/signin/page.tsx b/llama_stack/ui/app/auth/signin/page.tsx index c9510fd6b..0ccb4a397 100644 --- a/llama_stack/ui/app/auth/signin/page.tsx +++ b/llama_stack/ui/app/auth/signin/page.tsx @@ -51,9 +51,9 @@ export default function SignInPage() { onClick={() => { console.log("Signing in with GitHub..."); signIn("github", { callbackUrl: "/auth/signin" }).catch( - (error) => { + error => { console.error("Sign in error:", error); - }, + } ); }} className="w-full" diff --git a/llama_stack/ui/app/chat-playground/page.tsx b/llama_stack/ui/app/chat-playground/page.tsx index c31248b78..b8651aca0 100644 --- a/llama_stack/ui/app/chat-playground/page.tsx +++ b/llama_stack/ui/app/chat-playground/page.tsx @@ -29,14 +29,13 @@ export default function ChatPlaygroundPage() { const isModelsLoading = modelsLoading ?? true; - useEffect(() => { const fetchModels = async () => { try { setModelsLoading(true); setModelsError(null); const modelList = await client.models.list(); - const llmModels = modelList.filter(model => model.model_type === 'llm'); + const llmModels = modelList.filter(model => model.model_type === "llm"); setModels(llmModels); if (llmModels.length > 0) { setSelectedModel(llmModels[0].identifier); @@ -53,103 +52,122 @@ export default function ChatPlaygroundPage() { }, [client]); const extractTextContent = (content: unknown): string => { - if (typeof content === 'string') { + if (typeof content === "string") { return content; } if (Array.isArray(content)) { return content - .filter(item => item && typeof item === 'object' && 'type' in item && item.type === 'text') - .map(item => (item && typeof item === 'object' && 'text' in item) ? String(item.text) : '') - .join(''); + .filter( + item => + item && + typeof item === "object" && + "type" in item && + item.type === "text" + ) + .map(item => + item && typeof item === "object" && "text" in item + ? String(item.text) + : "" + ) + .join(""); } - if (content && typeof content === 'object' && 'type' in content && content.type === 'text' && 'text' in content) { - return String(content.text) || ''; + if ( + content && + typeof content === "object" && + "type" in content && + content.type === "text" && + "text" in content + ) { + return String(content.text) || ""; } - return ''; + return ""; }; const handleInputChange = (e: React.ChangeEvent) => { setInput(e.target.value); }; -const handleSubmit = async (event?: { preventDefault?: () => void }) => { - event?.preventDefault?.(); - if (!input.trim()) return; + const handleSubmit = async (event?: { preventDefault?: () => void }) => { + event?.preventDefault?.(); + if (!input.trim()) return; - // Add user message to chat - const userMessage: Message = { - id: Date.now().toString(), - role: "user", - content: input.trim(), - createdAt: new Date(), - }; - - setMessages(prev => [...prev, userMessage]); - setInput(""); - - // Use the helper function with the content - await handleSubmitWithContent(userMessage.content); -}; - -const handleSubmitWithContent = async (content: string) => { - setIsGenerating(true); - setError(null); - - try { - const messageParams: CompletionCreateParams["messages"] = [ - ...messages.map(msg => { - const msgContent = typeof msg.content === 'string' ? 
msg.content : extractTextContent(msg.content); - if (msg.role === "user") { - return { role: "user" as const, content: msgContent }; - } else if (msg.role === "assistant") { - return { role: "assistant" as const, content: msgContent }; - } else { - return { role: "system" as const, content: msgContent }; - } - }), - { role: "user" as const, content } - ]; - - const response = await client.chat.completions.create({ - model: selectedModel, - messages: messageParams, - stream: true, - }); - - const assistantMessage: Message = { - id: (Date.now() + 1).toString(), - role: "assistant", - content: "", + // Add user message to chat + const userMessage: Message = { + id: Date.now().toString(), + role: "user", + content: input.trim(), createdAt: new Date(), }; - setMessages(prev => [...prev, assistantMessage]); - let fullContent = ""; - for await (const chunk of response) { - if (chunk.choices && chunk.choices[0]?.delta?.content) { - const deltaContent = chunk.choices[0].delta.content; - fullContent += deltaContent; + setMessages(prev => [...prev, userMessage]); + setInput(""); - flushSync(() => { - setMessages(prev => { - const newMessages = [...prev]; - const lastMessage = newMessages[newMessages.length - 1]; - if (lastMessage.role === "assistant") { - lastMessage.content = fullContent; - } - return newMessages; + // Use the helper function with the content + await handleSubmitWithContent(userMessage.content); + }; + + const handleSubmitWithContent = async (content: string) => { + setIsGenerating(true); + setError(null); + + try { + const messageParams: CompletionCreateParams["messages"] = [ + ...messages.map(msg => { + const msgContent = + typeof msg.content === "string" + ? msg.content + : extractTextContent(msg.content); + if (msg.role === "user") { + return { role: "user" as const, content: msgContent }; + } else if (msg.role === "assistant") { + return { role: "assistant" as const, content: msgContent }; + } else { + return { role: "system" as const, content: msgContent }; + } + }), + { role: "user" as const, content }, + ]; + + const response = await client.chat.completions.create({ + model: selectedModel, + messages: messageParams, + stream: true, + }); + + const assistantMessage: Message = { + id: (Date.now() + 1).toString(), + role: "assistant", + content: "", + createdAt: new Date(), + }; + + setMessages(prev => [...prev, assistantMessage]); + let fullContent = ""; + for await (const chunk of response) { + if (chunk.choices && chunk.choices[0]?.delta?.content) { + const deltaContent = chunk.choices[0].delta.content; + fullContent += deltaContent; + + flushSync(() => { + setMessages(prev => { + const newMessages = [...prev]; + const lastMessage = newMessages[newMessages.length - 1]; + if (lastMessage.role === "assistant") { + lastMessage.content = fullContent; + } + return newMessages; + }); }); - }); + } } + } catch (err) { + console.error("Error sending message:", err); + setError("Failed to send message. Please try again."); + setMessages(prev => prev.slice(0, -1)); + } finally { + setIsGenerating(false); } - } catch (err) { - console.error("Error sending message:", err); - setError("Failed to send message. 
Please try again."); - setMessages(prev => prev.slice(0, -1)); - } finally { - setIsGenerating(false); - } -}; + }; const suggestions = [ "Write a Python function that prints 'Hello, World!'", "Explain step-by-step how to solve this math problem: If xΒ² + 6x + 9 = 25, what is x?", @@ -163,7 +181,7 @@ const handleSubmitWithContent = async (content: string) => { content: message.content, createdAt: new Date(), }; - setMessages(prev => [...prev, newMessage]) + setMessages(prev => [...prev, newMessage]); handleSubmitWithContent(newMessage.content); }; @@ -175,14 +193,22 @@ const handleSubmitWithContent = async (content: string) => { return (
-

Chat Playground

+

Chat Playground (Completions)

- - + - {models.map((model) => ( + {models.map(model => ( {model.identifier} diff --git a/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx b/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx index 82aa3496e..e11924f4c 100644 --- a/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx +++ b/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx @@ -33,12 +33,12 @@ export default function ChatCompletionDetailPage() { } catch (err) { console.error( `Error fetching chat completion detail for ID ${id}:`, - err, + err ); setError( err instanceof Error ? err - : new Error("Failed to fetch completion detail"), + : new Error("Failed to fetch completion detail") ); } finally { setIsLoading(false); diff --git a/llama_stack/ui/app/logs/responses/[id]/page.tsx b/llama_stack/ui/app/logs/responses/[id]/page.tsx index 7f4252856..922d35531 100644 --- a/llama_stack/ui/app/logs/responses/[id]/page.tsx +++ b/llama_stack/ui/app/logs/responses/[id]/page.tsx @@ -13,10 +13,10 @@ export default function ResponseDetailPage() { const client = useAuthClient(); const [responseDetail, setResponseDetail] = useState( - null, + null ); const [inputItems, setInputItems] = useState( - null, + null ); const [isLoading, setIsLoading] = useState(true); const [isLoadingInputItems, setIsLoadingInputItems] = useState(true); @@ -25,7 +25,7 @@ export default function ResponseDetailPage() { // Helper function to convert ResponseObject to OpenAIResponse const convertResponseObject = ( - responseData: ResponseObject, + responseData: ResponseObject ): OpenAIResponse => { return { id: responseData.id, @@ -73,12 +73,12 @@ export default function ResponseDetailPage() { } else { console.error( `Error fetching response detail for ID ${id}:`, - responseResult.reason, + responseResult.reason ); setError( responseResult.reason instanceof Error ? responseResult.reason - : new Error("Failed to fetch response detail"), + : new Error("Failed to fetch response detail") ); } @@ -90,18 +90,18 @@ export default function ResponseDetailPage() { } else { console.error( `Error fetching input items for response ID ${id}:`, - inputItemsResult.reason, + inputItemsResult.reason ); setInputItemsError( inputItemsResult.reason instanceof Error ? inputItemsResult.reason - : new Error("Failed to fetch input items"), + : new Error("Failed to fetch input items") ); } } catch (err) { console.error(`Unexpected error fetching data for ID ${id}:`, err); setError( - err instanceof Error ? err : new Error("Unexpected error occurred"), + err instanceof Error ? 
err : new Error("Unexpected error occurred") ); } finally { setIsLoading(false); diff --git a/llama_stack/ui/app/logs/vector-stores/[id]/files/[fileId]/contents/[contentId]/page.tsx b/llama_stack/ui/app/logs/vector-stores/[id]/files/[fileId]/contents/[contentId]/page.tsx index 6896b992a..d58de3085 100644 --- a/llama_stack/ui/app/logs/vector-stores/[id]/files/[fileId]/contents/[contentId]/page.tsx +++ b/llama_stack/ui/app/logs/vector-stores/[id]/files/[fileId]/contents/[contentId]/page.tsx @@ -18,7 +18,10 @@ import { PropertiesCard, PropertyItem, } from "@/components/layout/detail-layout"; -import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb"; +import { + PageBreadcrumb, + BreadcrumbSegment, +} from "@/components/layout/page-breadcrumb"; export default function ContentDetailPage() { const params = useParams(); @@ -28,13 +31,13 @@ export default function ContentDetailPage() { const contentId = params.contentId as string; const client = useAuthClient(); - const getTextFromContent = (content: any): string => { - if (typeof content === 'string') { + const getTextFromContent = (content: unknown): string => { + if (typeof content === "string") { return content; - } else if (content && content.type === 'text') { + } else if (content && content.type === "text") { return content.text; } - return ''; + return ""; }; const [store, setStore] = useState(null); @@ -44,7 +47,9 @@ export default function ContentDetailPage() { const [error, setError] = useState(null); const [isEditing, setIsEditing] = useState(false); const [editedContent, setEditedContent] = useState(""); - const [editedMetadata, setEditedMetadata] = useState>({}); + const [editedMetadata, setEditedMetadata] = useState>( + {} + ); const [isEditingEmbedding, setIsEditingEmbedding] = useState(false); const [editedEmbedding, setEditedEmbedding] = useState([]); @@ -64,8 +69,13 @@ export default function ContentDetailPage() { setFile(fileResponse as VectorStoreFile); const contentsAPI = new ContentsAPI(client); - const contentsResponse = await contentsAPI.listContents(vectorStoreId, fileId); - const targetContent = contentsResponse.data.find(c => c.id === contentId); + const contentsResponse = await contentsAPI.listContents( + vectorStoreId, + fileId + ); + const targetContent = contentsResponse.data.find( + c => c.id === contentId + ); if (targetContent) { setContent(targetContent); @@ -76,7 +86,9 @@ export default function ContentDetailPage() { throw new Error(`Content ${contentId} not found`); } } catch (err) { - setError(err instanceof Error ? err : new Error("Failed to load content.")); + setError( + err instanceof Error ? 
err : new Error("Failed to load content.") + ); } finally { setIsLoading(false); } @@ -88,7 +100,8 @@ export default function ContentDetailPage() { if (!content) return; try { - const updates: { content?: string; metadata?: Record } = {}; + const updates: { content?: string; metadata?: Record } = + {}; if (editedContent !== getTextFromContent(content.content)) { updates.content = editedContent; @@ -100,25 +113,32 @@ export default function ContentDetailPage() { if (Object.keys(updates).length > 0) { const contentsAPI = new ContentsAPI(client); - const updatedContent = await contentsAPI.updateContent(vectorStoreId, fileId, contentId, updates); + const updatedContent = await contentsAPI.updateContent( + vectorStoreId, + fileId, + contentId, + updates + ); setContent(updatedContent); } setIsEditing(false); } catch (err) { - console.error('Failed to update content:', err); + console.error("Failed to update content:", err); } }; const handleDelete = async () => { - if (!confirm('Are you sure you want to delete this content?')) return; + if (!confirm("Are you sure you want to delete this content?")) return; try { const contentsAPI = new ContentsAPI(client); await contentsAPI.deleteContent(vectorStoreId, fileId, contentId); - router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`); + router.push( + `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents` + ); } catch (err) { - console.error('Failed to delete content:', err); + console.error("Failed to delete content:", err); } }; @@ -134,10 +154,19 @@ export default function ContentDetailPage() { const breadcrumbSegments: BreadcrumbSegment[] = [ { label: "Vector Stores", href: "/logs/vector-stores" }, - { label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` }, + { + label: store?.name || vectorStoreId, + href: `/logs/vector-stores/${vectorStoreId}`, + }, { label: "Files", href: `/logs/vector-stores/${vectorStoreId}` }, - { label: fileId, href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}` }, - { label: "Contents", href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents` }, + { + label: fileId, + href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}`, + }, + { + label: "Contents", + href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`, + }, { label: contentId }, ]; @@ -186,7 +215,7 @@ export default function ContentDetailPage() { {isEditing ? (