Merge branch 'main' into dependabot/uv/huggingface-hub-0.34.4

This commit is contained in:
Ashwin Bharambe 2025-08-15 10:54:28 -07:00 committed by GitHub
commit c548abcada
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
194 changed files with 12504 additions and 7718 deletions

View file

@ -28,7 +28,7 @@ runs:
# Install llama-stack-client-python based on the client-version input
if [ "${{ inputs.client-version }}" = "latest" ]; then
echo "Installing latest llama-stack-client-python from main branch"
uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
uv pip install git+https://github.com/llamastack/llama-stack-client-python.git@main
elif [ "${{ inputs.client-version }}" = "published" ]; then
echo "Installing published llama-stack-client-python from PyPI"
uv pip install llama-stack-client

View file

@ -52,7 +52,8 @@ jobs:
run: |
# Get test directories dynamically, excluding non-test directories
# NOTE: we are excluding post_training since the tests take too long
TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d |
sed 's|tests/integration/||' |
grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
sort | jq -R -s -c 'split("\n")[:-1]')
echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT

View file

@ -14,9 +14,11 @@ on:
- 'pyproject.toml'
- 'requirements.txt'
- '.github/workflows/integration-vector-io-tests.yml' # This workflow
schedule:
- cron: '0 0 * * *' # (test on python 3.13) Daily at 12 AM UTC
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -25,7 +27,7 @@ jobs:
strategy:
matrix:
vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"]
python-version: ["3.12", "3.13"]
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
fail-fast: false # we want to run all tests regardless of failure
steps:
@ -164,9 +166,9 @@ jobs:
ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }}
WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
run: |
uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
uv run pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
tests/integration/vector_io \
--embedding-model sentence-transformers/all-MiniLM-L6-v2
--embedding-model inline::sentence-transformers/all-MiniLM-L6-v2
- name: Check Storage and Memory Available After Tests
if: ${{ always() }}

View file

@ -3,7 +3,7 @@ name: Integration Tests (Record)
run-name: Run the integration test suite from tests/integration
on:
pull_request:
pull_request_target:
branches: [ main ]
types: [opened, synchronize, labeled]
paths:
@ -23,7 +23,7 @@ on:
default: 'ollama'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.event.pull_request.number }}
cancel-in-progress: true
jobs:

View file

@ -11,7 +11,7 @@ on:
- synchronize
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.event.pull_request.number }}
cancel-in-progress: true
permissions:

View file

@ -2,6 +2,7 @@ exclude: 'build/'
default_language_version:
python: python3.12
node: "22"
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
@ -145,6 +146,20 @@ repos:
pass_filenames: false
require_serial: true
files: ^.github/workflows/.*$
- id: ui-prettier
name: Format UI code with Prettier
entry: bash -c 'cd llama_stack/ui && npm run format'
language: system
files: ^llama_stack/ui/.*\.(ts|tsx)$
pass_filenames: false
require_serial: true
- id: ui-eslint
name: Lint UI code with ESLint
entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet'
language: system
files: ^llama_stack/ui/.*\.(ts|tsx)$
pass_filenames: false
require_serial: true
ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

View file

@ -1,13 +1,82 @@
# Contributing to Llama-Stack
# Contributing to Llama Stack
We want to make contributing to this project as easy and transparent as
possible.
## Set up your development environment
We use [uv](https://github.com/astral-sh/uv) to manage python dependencies and virtual environments.
You can install `uv` by following this [guide](https://docs.astral.sh/uv/getting-started/installation/).
You can install the dependencies by running:
```bash
cd llama-stack
uv sync --group dev
uv pip install -e .
source .venv/bin/activate
```
```{note}
You can use a specific version of Python with `uv` by adding the `--python <version>` flag (e.g. `--python 3.12`).
Otherwise, `uv` will automatically select a Python version according to the `requires-python` section of the `pyproject.toml`.
For more info, see the [uv docs around Python versions](https://docs.astral.sh/uv/concepts/python-versions/).
```
Note that you can create a dotenv file `.env` that includes necessary environment variables:
```
LLAMA_STACK_BASE_URL=http://localhost:8321
LLAMA_STACK_CLIENT_LOG=debug
LLAMA_STACK_PORT=8321
LLAMA_STACK_CONFIG=<provider-name>
TAVILY_SEARCH_API_KEY=
BRAVE_SEARCH_API_KEY=
```
And then use this dotenv file when running client SDK tests via the following:
```bash
uv run --env-file .env -- pytest -v tests/integration/inference/test_text_inference.py --text-model=meta-llama/Llama-3.1-8B-Instruct
```
### Pre-commit Hooks
We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running:
```bash
uv run pre-commit install
```
After that, pre-commit hooks will run automatically before each commit.
Alternatively, if you don't want to install the pre-commit hooks, you can run the checks manually by running:
```bash
uv run pre-commit run --all-files
```
```{caution}
Before pushing your changes, make sure that the pre-commit hooks have passed successfully.
```
## Discussions -> Issues -> Pull Requests
We actively welcome your pull requests. However, please read the following. This is heavily inspired by [Ghostty](https://github.com/ghostty-org/ghostty/blob/main/CONTRIBUTING.md).
If in doubt, please open a [discussion](https://github.com/meta-llama/llama-stack/discussions); we can always convert that to an issue later.
### Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Meta has a [bounty program](http://facebook.com/whitehat/info) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
### Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
**I'd like to contribute!**
If you are new to the project, start by looking at the issues tagged with "good first issue". If you're interested
@ -51,93 +120,15 @@ Please avoid picking up too many issues at once. This helps you stay focused and
Please keep pull requests (PRs) small and focused. If you have a large set of changes, consider splitting them into logically grouped, smaller PRs to facilitate review and testing.
> [!TIP]
> As a general guideline:
> - Experienced contributors should try to keep no more than 5 open PRs at a time.
> - New contributors are encouraged to have only one open PR at a time until theyre familiar with the codebase and process.
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Meta has a [bounty program](http://facebook.com/whitehat/info) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
## Set up your development environment
We use [uv](https://github.com/astral-sh/uv) to manage python dependencies and virtual environments.
You can install `uv` by following this [guide](https://docs.astral.sh/uv/getting-started/installation/).
You can install the dependencies by running:
```bash
cd llama-stack
uv sync --group dev
uv pip install -e .
source .venv/bin/activate
```{tip}
As a general guideline:
- Experienced contributors should try to keep no more than 5 open PRs at a time.
- New contributors are encouraged to have only one open PR at a time until theyre familiar with the codebase and process.
```
> [!NOTE]
> You can use a specific version of Python with `uv` by adding the `--python <version>` flag (e.g. `--python 3.12`)
> Otherwise, `uv` will automatically select a Python version according to the `requires-python` section of the `pyproject.toml`.
> For more info, see the [uv docs around Python versions](https://docs.astral.sh/uv/concepts/python-versions/).
## Repository guidelines
Note that you can create a dotenv file `.env` that includes necessary environment variables:
```
LLAMA_STACK_BASE_URL=http://localhost:8321
LLAMA_STACK_CLIENT_LOG=debug
LLAMA_STACK_PORT=8321
LLAMA_STACK_CONFIG=<provider-name>
TAVILY_SEARCH_API_KEY=
BRAVE_SEARCH_API_KEY=
```
And then use this dotenv file when running client SDK tests via the following:
```bash
uv run --env-file .env -- pytest -v tests/integration/inference/test_text_inference.py --text-model=meta-llama/Llama-3.1-8B-Instruct
```
## Pre-commit Hooks
We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running:
```bash
uv run pre-commit install
```
After that, pre-commit hooks will run automatically before each commit.
Alternatively, if you don't want to install the pre-commit hooks, you can run the checks manually by running:
```bash
uv run pre-commit run --all-files
```
> [!CAUTION]
> Before pushing your changes, make sure that the pre-commit hooks have passed successfully.
## Running tests
You can find the Llama Stack testing documentation [here](https://github.com/meta-llama/llama-stack/blob/main/tests/README.md).
## Adding a new dependency to the project
To add a new dependency to the project, you can use the `uv` command. For example, to add `foo` to the project, you can run:
```bash
uv add foo
uv sync
```
## Coding Style
### Coding Style
* Comments should provide meaningful insights into the code. Avoid filler comments that simply
describe the next step, as they create unnecessary clutter, same goes for docstrings.
@ -159,6 +150,10 @@ uv sync
* When possible, use keyword arguments only when calling functions.
* Llama Stack utilizes [custom Exception classes](llama_stack/apis/common/errors.py) for certain Resources that should be used where applicable.
### License
By contributing to Llama, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
## Common Tasks
Some tips about common tasks you work on while contributing to Llama Stack:
@ -211,7 +206,3 @@ uv run ./docs/openapi_generator/run_openapi_generator.sh
```
The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.
## License
By contributing to Llama, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

View file

@ -9,6 +9,7 @@
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)
### ✨🎉 Llama 4 Support 🎉✨
We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.
@ -179,3 +180,17 @@ Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest
Check out our client SDKs for connecting to a Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [typescript](https://github.com/meta-llama/llama-stack-client-typescript), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications.
You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repo.
## 🌟 GitHub Star History
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=meta-llama/llama-stack&type=Date)](https://www.star-history.com/#meta-llama/llama-stack&Date)
## ✨ Contributors
Thanks to all of our amazing contributors!
<a href="https://github.com/meta-llama/llama-stack/graphs/contributors">
<img src="https://contrib.rocks/image?repo=meta-llama/llama-stack" />
</a>

14
docs/_static/js/keyboard_shortcuts.js vendored Normal file
View file

@ -0,0 +1,14 @@
document.addEventListener('keydown', function(event) {
// command+K or ctrl+K
if ((event.metaKey || event.ctrlKey) && event.key === 'k') {
event.preventDefault();
document.querySelector('.search-input, .search-field, input[name="q"]').focus();
}
// forward slash
if (event.key === '/' &&
!event.target.matches('input, textarea, select')) {
event.preventDefault();
document.querySelector('.search-input, .search-field, input[name="q"]').focus();
}
});

View file

@ -8292,6 +8292,9 @@
"results": {
"type": "array",
"items": {
"type": "object",
"properties": {
"attributes": {
"type": "object",
"additionalProperties": {
"oneOf": [
@ -8314,8 +8317,37 @@
"type": "object"
}
]
},
"description": "(Optional) Key-value attributes associated with the file"
},
"file_id": {
"type": "string",
"description": "Unique identifier of the file containing the result"
},
"filename": {
"type": "string",
"description": "Name of the file containing the result"
},
"score": {
"type": "number",
"description": "Relevance score for this search result (between 0 and 1)"
},
"text": {
"type": "string",
"description": "Text content of the search result"
}
},
"additionalProperties": false,
"required": [
"attributes",
"file_id",
"filename",
"score",
"text"
],
"title": "OpenAIResponseOutputMessageFileSearchToolCallResults",
"description": "Search results returned by the file search operation."
},
"description": "(Optional) Search results returned by the file search operation"
}
},
@ -8515,6 +8547,13 @@
"$ref": "#/components/schemas/OpenAIResponseInputTool"
}
},
"include": {
"type": "array",
"items": {
"type": "string"
},
"description": "(Optional) Additional fields to include in the response."
},
"max_infer_iters": {
"type": "integer"
}
@ -8782,6 +8821,61 @@
"title": "OpenAIResponseOutputMessageMCPListTools",
"description": "MCP list tools output message containing available tools from an MCP server."
},
"OpenAIResponseContentPart": {
"oneOf": [
{
"$ref": "#/components/schemas/OpenAIResponseContentPartOutputText"
},
{
"$ref": "#/components/schemas/OpenAIResponseContentPartRefusal"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"output_text": "#/components/schemas/OpenAIResponseContentPartOutputText",
"refusal": "#/components/schemas/OpenAIResponseContentPartRefusal"
}
}
},
"OpenAIResponseContentPartOutputText": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "output_text",
"default": "output_text"
},
"text": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"type",
"text"
],
"title": "OpenAIResponseContentPartOutputText"
},
"OpenAIResponseContentPartRefusal": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "refusal",
"default": "refusal"
},
"refusal": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"type",
"refusal"
],
"title": "OpenAIResponseContentPartRefusal"
},
"OpenAIResponseObjectStream": {
"oneOf": [
{
@ -8838,6 +8932,12 @@
{
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted"
},
{
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded"
},
{
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone"
},
{
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
}
@ -8863,6 +8963,8 @@
"response.mcp_call.in_progress": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress",
"response.mcp_call.failed": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed",
"response.mcp_call.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted",
"response.content_part.added": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded",
"response.content_part.done": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone",
"response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
}
}
@ -8889,6 +8991,80 @@
"title": "OpenAIResponseObjectStreamResponseCompleted",
"description": "Streaming event indicating a response has been completed."
},
"OpenAIResponseObjectStreamResponseContentPartAdded": {
"type": "object",
"properties": {
"response_id": {
"type": "string",
"description": "Unique identifier of the response containing this content"
},
"item_id": {
"type": "string",
"description": "Unique identifier of the output item containing this content part"
},
"part": {
"$ref": "#/components/schemas/OpenAIResponseContentPart",
"description": "The content part that was added"
},
"sequence_number": {
"type": "integer",
"description": "Sequential number for ordering streaming events"
},
"type": {
"type": "string",
"const": "response.content_part.added",
"default": "response.content_part.added",
"description": "Event type identifier, always \"response.content_part.added\""
}
},
"additionalProperties": false,
"required": [
"response_id",
"item_id",
"part",
"sequence_number",
"type"
],
"title": "OpenAIResponseObjectStreamResponseContentPartAdded",
"description": "Streaming event for when a new content part is added to a response item."
},
"OpenAIResponseObjectStreamResponseContentPartDone": {
"type": "object",
"properties": {
"response_id": {
"type": "string",
"description": "Unique identifier of the response containing this content"
},
"item_id": {
"type": "string",
"description": "Unique identifier of the output item containing this content part"
},
"part": {
"$ref": "#/components/schemas/OpenAIResponseContentPart",
"description": "The completed content part"
},
"sequence_number": {
"type": "integer",
"description": "Sequential number for ordering streaming events"
},
"type": {
"type": "string",
"const": "response.content_part.done",
"default": "response.content_part.done",
"description": "Event type identifier, always \"response.content_part.done\""
}
},
"additionalProperties": false,
"required": [
"response_id",
"item_id",
"part",
"sequence_number",
"type"
],
"title": "OpenAIResponseObjectStreamResponseContentPartDone",
"description": "Streaming event for when a content part is completed."
},
"OpenAIResponseObjectStreamResponseCreated": {
"type": "object",
"properties": {
@ -16530,7 +16706,7 @@
"additionalProperties": {
"type": "number"
},
"description": "A list of the categories along with their scores as predicted by model. Required set of categories that need to be in response - violence - violence/graphic - harassment - harassment/threatening - hate - hate/threatening - illicit - illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent - self-harm/instructions"
"description": "A list of the categories along with their scores as predicted by model."
},
"user_message": {
"type": "string"

View file

@ -6020,6 +6020,9 @@ components:
results:
type: array
items:
type: object
properties:
attributes:
type: object
additionalProperties:
oneOf:
@ -6029,6 +6032,33 @@ components:
- type: string
- type: array
- type: object
description: >-
(Optional) Key-value attributes associated with the file
file_id:
type: string
description: >-
Unique identifier of the file containing the result
filename:
type: string
description: Name of the file containing the result
score:
type: number
description: >-
Relevance score for this search result (between 0 and 1)
text:
type: string
description: Text content of the search result
additionalProperties: false
required:
- attributes
- file_id
- filename
- score
- text
title: >-
OpenAIResponseOutputMessageFileSearchToolCallResults
description: >-
Search results returned by the file search operation.
description: >-
(Optional) Search results returned by the file search operation
additionalProperties: false
@ -6188,6 +6218,12 @@ components:
type: array
items:
$ref: '#/components/schemas/OpenAIResponseInputTool'
include:
type: array
items:
type: string
description: >-
(Optional) Additional fields to include in the response.
max_infer_iters:
type: integer
additionalProperties: false
@ -6405,6 +6441,43 @@ components:
title: OpenAIResponseOutputMessageMCPListTools
description: >-
MCP list tools output message containing available tools from an MCP server.
OpenAIResponseContentPart:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseContentPartOutputText'
- $ref: '#/components/schemas/OpenAIResponseContentPartRefusal'
discriminator:
propertyName: type
mapping:
output_text: '#/components/schemas/OpenAIResponseContentPartOutputText'
refusal: '#/components/schemas/OpenAIResponseContentPartRefusal'
OpenAIResponseContentPartOutputText:
type: object
properties:
type:
type: string
const: output_text
default: output_text
text:
type: string
additionalProperties: false
required:
- type
- text
title: OpenAIResponseContentPartOutputText
OpenAIResponseContentPartRefusal:
type: object
properties:
type:
type: string
const: refusal
default: refusal
refusal:
type: string
additionalProperties: false
required:
- type
- refusal
title: OpenAIResponseContentPartRefusal
OpenAIResponseObjectStream:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
@ -6425,6 +6498,8 @@ components:
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress'
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed'
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted'
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded'
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone'
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
discriminator:
propertyName: type
@ -6447,6 +6522,8 @@ components:
response.mcp_call.in_progress: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress'
response.mcp_call.failed: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed'
response.mcp_call.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted'
response.content_part.added: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded'
response.content_part.done: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone'
response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
"OpenAIResponseObjectStreamResponseCompleted":
type: object
@ -6468,6 +6545,76 @@ components:
OpenAIResponseObjectStreamResponseCompleted
description: >-
Streaming event indicating a response has been completed.
"OpenAIResponseObjectStreamResponseContentPartAdded":
type: object
properties:
response_id:
type: string
description: >-
Unique identifier of the response containing this content
item_id:
type: string
description: >-
Unique identifier of the output item containing this content part
part:
$ref: '#/components/schemas/OpenAIResponseContentPart'
description: The content part that was added
sequence_number:
type: integer
description: >-
Sequential number for ordering streaming events
type:
type: string
const: response.content_part.added
default: response.content_part.added
description: >-
Event type identifier, always "response.content_part.added"
additionalProperties: false
required:
- response_id
- item_id
- part
- sequence_number
- type
title: >-
OpenAIResponseObjectStreamResponseContentPartAdded
description: >-
Streaming event for when a new content part is added to a response item.
"OpenAIResponseObjectStreamResponseContentPartDone":
type: object
properties:
response_id:
type: string
description: >-
Unique identifier of the response containing this content
item_id:
type: string
description: >-
Unique identifier of the output item containing this content part
part:
$ref: '#/components/schemas/OpenAIResponseContentPart'
description: The completed content part
sequence_number:
type: integer
description: >-
Sequential number for ordering streaming events
type:
type: string
const: response.content_part.done
default: response.content_part.done
description: >-
Event type identifier, always "response.content_part.done"
additionalProperties: false
required:
- response_id
- item_id
- part
- sequence_number
- type
title: >-
OpenAIResponseObjectStreamResponseContentPartDone
description: >-
Streaming event for when a content part is completed.
"OpenAIResponseObjectStreamResponseCreated":
type: object
properties:
@ -12286,10 +12433,6 @@ components:
type: number
description: >-
A list of the categories along with their scores as predicted by model.
Required set of categories that need to be in response - violence - violence/graphic
- harassment - harassment/threatening - hate - hate/threatening - illicit
- illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent
- self-harm/instructions
user_message:
type: string
metadata:

View file

@ -111,7 +111,7 @@ name = "llama-stack-api-weather"
version = "0.1.0"
description = "Weather API for Llama Stack"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.12"
dependencies = ["llama-stack", "pydantic"]
[build-system]
@ -231,7 +231,7 @@ name = "llama-stack-provider-kaze"
version = "0.1.0"
description = "Kaze weather provider for Llama Stack"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.12"
dependencies = ["llama-stack", "pydantic", "aiohttp"]
[build-system]

View file

@ -2,7 +2,9 @@
Llama Stack (LLS) provides two different APIs for building AI applications with tool calling capabilities: the **Agents API** and the **OpenAI Responses API**. While both enable AI systems to use tools, and maintain full conversation history, they serve different use cases and have distinct characteristics.
> **Note:** For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to Agents or Responses API.
```{note}
For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to Agents or Responses API.
```
## Overview

View file

@ -76,7 +76,9 @@ Features:
- Context retrieval with token limits
> **Note:** By default, llama stack run.yaml defines toolgroups for web search, wolfram alpha and rag, that are provided by tavily-search, wolfram-alpha and rag providers.
```{note}
By default, llama stack run.yaml defines toolgroups for web search, wolfram alpha and rag, that are provided by tavily-search, wolfram-alpha and rag providers.
```
## Model Context Protocol (MCP)

View file

@ -131,6 +131,7 @@ html_static_path = ["../_static"]
def setup(app):
app.add_css_file("css/my_theme.css")
app.add_js_file("js/detect_theme.js")
app.add_js_file("js/keyboard_shortcuts.js")
def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
url = f"https://hub.docker.com/r/llamastack/{text}"

View file

@ -2,14 +2,33 @@
```{include} ../../../CONTRIBUTING.md
```
See the [Adding a New API Provider](new_api_provider.md) which describes how to add new API providers to the Stack.
## Adding a New Provider
See:
- [Adding a New API Provider Page](new_api_provider.md) which describes how to add new API providers to the Stack.
- [Vector Database Page](new_vector_database.md) which describes how to add a new vector databases with Llama Stack.
- [External Provider Page](../providers/external/index.md) which describes how to add external providers to the Stack.
```{toctree}
:maxdepth: 1
:hidden:
new_api_provider
testing
new_vector_database
```
## Testing
```{include} ../../../tests/README.md
```
### Advanced Topics
For developers who need deeper understanding of the testing system internals:
```{toctree}
:maxdepth: 1
testing/record-replay
```

View file

@ -0,0 +1,75 @@
# Adding a New Vector Database
This guide will walk you through the process of adding a new vector database to Llama Stack.
> **_NOTE:_** Here's an example Pull Request of the [Milvus Vector Database Provider](https://github.com/meta-llama/llama-stack/pull/1467).
Vector Database providers are used to store and retrieve vector embeddings. Vector databases are not limited to vector
search but can support keyword and hybrid search. Additionally, vector database can also support operations like
filtering, sorting, and aggregating vectors.
## Steps to Add a New Vector Database Provider
1. **Choose the Database Type**: Determine if your vector database is a remote service, inline, or both.
- Remote databases make requests to external services, while inline databases execute locally. Some providers support both.
2. **Implement the Provider**: Create a new provider class that inherits from `VectorDatabaseProvider` and implements the required methods.
- Implement methods for vector storage, retrieval, search, and any additional features your database supports.
- You will need to implement the following methods for `YourVectorIndex`:
- `YourVectorIndex.create()`
- `YourVectorIndex.initialize()`
- `YourVectorIndex.add_chunks()`
- `YourVectorIndex.delete_chunk()`
- `YourVectorIndex.query_vector()`
- `YourVectorIndex.query_keyword()`
- `YourVectorIndex.query_hybrid()`
- You will need to implement the following methods for `YourVectorIOAdapter`:
- `YourVectorIOAdapter.initialize()`
- `YourVectorIOAdapter.shutdown()`
- `YourVectorIOAdapter.list_vector_dbs()`
- `YourVectorIOAdapter.register_vector_db()`
- `YourVectorIOAdapter.unregister_vector_db()`
- `YourVectorIOAdapter.insert_chunks()`
- `YourVectorIOAdapter.query_chunks()`
- `YourVectorIOAdapter.delete_chunks()`
3. **Add to Registry**: Register your provider in the appropriate registry file.
- Update {repopath}`llama_stack/providers/registry/vector_io.py` to include your new provider.
```python
from llama_stack.providers.registry.specs import InlineProviderSpec
from llama_stack.providers.registry.api import Api
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::milvus",
pip_packages=["pymilvus>=2.4.10"],
module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="",
),
```
4. **Add Tests**: Create unit tests and integration tests for your provider in the `tests/` directory.
- Unit Tests
- By following the structure of the class methods, you will be able to easily run unit and integration tests for your database.
1. You have to configure the tests for your provide in `/tests/unit/providers/vector_io/conftest.py`.
2. Update the `vector_provider` fixture to include your provider if they are an inline provider.
3. Create a `your_vectorprovider_index` fixture that initializes your vector index.
4. Create a `your_vectorprovider_adapter` fixture that initializes your vector adapter.
5. Add your provider to the `vector_io_providers` fixture dictionary.
- Please follow the naming convention of `your_vectorprovider_index` and `your_vectorprovider_adapter` as the tests require this to execute properly.
- Integration Tests
- Integration tests are located in {repopath}`tests/integration`. These tests use the python client-SDK APIs (from the `llama_stack_client` package) to test functionality.
- The two set of integration tests are:
- `tests/integration/vector_io/test_vector_io.py`: This file tests registration, insertion, and retrieval.
- `tests/integration/vector_io/test_openai_vector_stores.py`: These tests are for OpenAI-compatible vector stores and test the OpenAI API compatibility.
- You will need to update `skip_if_provider_doesnt_support_openai_vector_stores` to include your provider as well as `skip_if_provider_doesnt_support_openai_vector_stores_search` to test the appropriate search functionality.
- Running the tests in the GitHub CI
- You will need to update the `.github/workflows/integration-vector-io-tests.yml` file to include your provider.
- If your provider is a remote provider, you will also have to add a container to spin up and run it in the action.
- Updating the pyproject.yml
- If you are adding tests for the `inline` provider you will have to update the `unit` group.
- `uv add new_pip_package --group unit`
- If you are adding tests for the `remote` provider you will have to update the `test` group, which is used in the GitHub CI for integration tests.
- `uv add new_pip_package --group test`
5. **Update Documentation**: Please update the documentation for end users
- Generate the provider documentation by running {repopath}`./scripts/provider_codegen.py`.
- Update the autogenerated content in the registry/vector_io.py file with information about your provider. Please see other providers for examples.

View file

@ -1,6 +0,0 @@
# Testing Llama Stack
Tests are of three different kinds:
- Unit tests
- Provider focused integration tests
- Client SDK tests

View file

@ -0,0 +1,234 @@
# Record-Replay System
Understanding how Llama Stack captures and replays API interactions for testing.
## Overview
The record-replay system solves a fundamental challenge in AI testing: how do you test against expensive, non-deterministic APIs without breaking the bank or dealing with flaky tests?
The solution: intercept API calls, store real responses, and replay them later. This gives you real API behavior without the cost or variability.
## How It Works
### Request Hashing
Every API request gets converted to a deterministic hash for lookup:
```python
def normalize_request(method: str, url: str, headers: dict, body: dict) -> str:
normalized = {
"method": method.upper(),
"endpoint": urlparse(url).path, # Just the path, not full URL
"body": body, # Request parameters
}
return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()
```
**Key insight:** The hashing is intentionally precise. Different whitespace, float precision, or parameter order produces different hashes. This prevents subtle bugs from false cache hits.
```python
# These produce DIFFERENT hashes:
{"content": "Hello world"}
{"content": "Hello world\n"}
{"temperature": 0.7}
{"temperature": 0.7000001}
```
### Client Interception
The system patches OpenAI and Ollama client methods to intercept calls before they leave your application. This happens transparently - your test code doesn't change.
### Storage Architecture
Recordings use a two-tier storage system optimized for both speed and debuggability:
```
recordings/
├── index.sqlite # Fast lookup by request hash
└── responses/
├── abc123def456.json # Individual response files
└── def789ghi012.json
```
**SQLite index** enables O(log n) hash lookups and metadata queries without loading response bodies.
**JSON files** store complete request/response pairs in human-readable format for debugging.
## Recording Modes
### LIVE Mode
Direct API calls with no recording or replay:
```python
with inference_recording(mode=InferenceMode.LIVE):
response = await client.chat.completions.create(...)
```
Use for initial development and debugging against real APIs.
### RECORD Mode
Captures API interactions while passing through real responses:
```python
with inference_recording(mode=InferenceMode.RECORD, storage_dir="./recordings"):
response = await client.chat.completions.create(...)
# Real API call made, response captured AND returned
```
The recording process:
1. Request intercepted and hashed
2. Real API call executed
3. Response captured and serialized
4. Recording stored to disk
5. Original response returned to caller
### REPLAY Mode
Returns stored responses instead of making API calls:
```python
with inference_recording(mode=InferenceMode.REPLAY, storage_dir="./recordings"):
response = await client.chat.completions.create(...)
# No API call made, cached response returned instantly
```
The replay process:
1. Request intercepted and hashed
2. Hash looked up in SQLite index
3. Response loaded from JSON file
4. Response deserialized and returned
5. Error if no recording found
## Streaming Support
Streaming APIs present a unique challenge: how do you capture an async generator?
### The Problem
```python
# How do you record this?
async for chunk in client.chat.completions.create(stream=True):
process(chunk)
```
### The Solution
The system captures all chunks immediately before yielding any:
```python
async def handle_streaming_record(response):
# Capture complete stream first
chunks = []
async for chunk in response:
chunks.append(chunk)
# Store complete recording
storage.store_recording(
request_hash, request_data, {"body": chunks, "is_streaming": True}
)
# Return generator that replays captured chunks
async def replay_stream():
for chunk in chunks:
yield chunk
return replay_stream()
```
This ensures:
- **Complete capture** - The entire stream is saved atomically
- **Interface preservation** - The returned object behaves like the original API
- **Deterministic replay** - Same chunks in the same order every time
## Serialization
API responses contain complex Pydantic objects that need careful serialization:
```python
def _serialize_response(response):
if hasattr(response, "model_dump"):
# Preserve type information for proper deserialization
return {
"__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}",
"__data__": response.model_dump(mode="json"),
}
return response
```
This preserves type safety - when replayed, you get the same Pydantic objects with all their validation and methods.
## Environment Integration
### Environment Variables
Control recording behavior globally:
```bash
export LLAMA_STACK_TEST_INFERENCE_MODE=replay
export LLAMA_STACK_TEST_RECORDING_DIR=/path/to/recordings
pytest tests/integration/
```
### Pytest Integration
The system integrates automatically based on environment variables, requiring no changes to test code.
## Debugging Recordings
### Inspecting Storage
```bash
# See what's recorded
sqlite3 recordings/index.sqlite "SELECT endpoint, model, timestamp FROM recordings LIMIT 10;"
# View specific response
cat recordings/responses/abc123def456.json | jq '.response.body'
# Find recordings by endpoint
sqlite3 recordings/index.sqlite "SELECT * FROM recordings WHERE endpoint='/v1/chat/completions';"
```
### Common Issues
**Hash mismatches:** Request parameters changed slightly between record and replay
```bash
# Compare request details
cat recordings/responses/abc123.json | jq '.request'
```
**Serialization errors:** Response types changed between versions
```bash
# Re-record with updated types
rm recordings/responses/failing_hash.json
LLAMA_STACK_TEST_INFERENCE_MODE=record pytest test_failing.py
```
**Missing recordings:** New test or changed parameters
```bash
# Record the missing interaction
LLAMA_STACK_TEST_INFERENCE_MODE=record pytest test_new.py
```
## Design Decisions
### Why Not Mocks?
Traditional mocking breaks down with AI APIs because:
- Response structures are complex and evolve frequently
- Streaming behavior is hard to mock correctly
- Edge cases in real APIs get missed
- Mocks become brittle maintenance burdens
### Why Precise Hashing?
Loose hashing (normalizing whitespace, rounding floats) seems convenient but hides bugs. If a test changes slightly, you want to know about it rather than accidentally getting the wrong cached response.
### Why JSON + SQLite?
- **JSON** - Human readable, diff-friendly, easy to inspect and modify
- **SQLite** - Fast indexed lookups without loading response bodies
- **Hybrid** - Best of both worlds for different use cases
This system provides reliable, fast testing against real AI APIs while maintaining the ability to debug issues when they arise.

View file

@ -0,0 +1,57 @@
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Deploys the benchmark-specific components on top of the base k8s deployment (../k8s/apply.sh).
export MOCK_INFERENCE_PORT=8080
export STREAM_DELAY_SECONDS=0.005
export POSTGRES_USER=llamastack
export POSTGRES_DB=llamastack
export POSTGRES_PASSWORD=llamastack
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export MOCK_INFERENCE_MODEL=mock-inference
# Use llama-stack-benchmark-service as the benchmark server
export LOCUST_HOST=http://llama-stack-benchmark-service:8323
export LOCUST_BASE_PATH=/v1/openai/v1
# Use vllm-service as the benchmark server
# export LOCUST_HOST=http://vllm-server:8000
# export LOCUST_BASE_PATH=/v1
export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
set -euo pipefail
set -x
# Deploy benchmark-specific components
# Deploy OpenAI mock server
kubectl create configmap openai-mock --from-file=openai-mock-server.py \
--dry-run=client -o yaml | kubectl apply --validate=false -f -
envsubst < openai-mock-deployment.yaml | kubectl apply --validate=false -f -
# Create configmap with our custom stack config
kubectl create configmap llama-stack-config --from-file=stack_run_config.yaml \
--dry-run=client -o yaml > stack-configmap.yaml
kubectl apply --validate=false -f stack-configmap.yaml
# Deploy our custom llama stack server (overriding the base one)
envsubst < stack-k8s.yaml.template | kubectl apply --validate=false -f -
# Deploy Locust load testing
kubectl create configmap locust-script --from-file=locustfile.py \
--dry-run=client -o yaml | kubectl apply --validate=false -f -
envsubst < locust-k8s.yaml | kubectl apply --validate=false -f -

View file

@ -0,0 +1,131 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: locust-master
labels:
app: locust
role: master
spec:
replicas: 1
selector:
matchLabels:
app: locust
role: master
template:
metadata:
labels:
app: locust
role: master
spec:
containers:
- name: locust-master
image: locustio/locust:2.31.8
ports:
- containerPort: 8089 # Web UI
- containerPort: 5557 # Master communication
env:
- name: LOCUST_HOST
value: "${LOCUST_HOST}"
- name: LOCUST_LOCUSTFILE
value: "/locust/locustfile.py"
- name: LOCUST_WEB_HOST
value: "0.0.0.0"
- name: LOCUST_MASTER
value: "true"
- name: LOCUST_BASE_PATH
value: "${LOCUST_BASE_PATH}"
- name: INFERENCE_MODEL
value: "${BENCHMARK_INFERENCE_MODEL}"
volumeMounts:
- name: locust-script
mountPath: /locust
command: ["locust"]
args:
- "--master"
- "--web-host=0.0.0.0"
- "--web-port=8089"
- "--host=${LOCUST_HOST}"
- "--locustfile=/locust/locustfile.py"
volumes:
- name: locust-script
configMap:
name: locust-script
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: locust-worker
labels:
app: locust
role: worker
spec:
replicas: 2 # Start with 2 workers, can be scaled up
selector:
matchLabels:
app: locust
role: worker
template:
metadata:
labels:
app: locust
role: worker
spec:
containers:
- name: locust-worker
image: locustio/locust:2.31.8
env:
- name: LOCUST_HOST
value: "${LOCUST_HOST}"
- name: LOCUST_LOCUSTFILE
value: "/locust/locustfile.py"
- name: LOCUST_MASTER_HOST
value: "locust-master-service"
- name: LOCUST_MASTER_PORT
value: "5557"
- name: INFERENCE_MODEL
value: "${BENCHMARK_INFERENCE_MODEL}"
- name: LOCUST_BASE_PATH
value: "${LOCUST_BASE_PATH}"
volumeMounts:
- name: locust-script
mountPath: /locust
command: ["locust"]
args:
- "--worker"
- "--master-host=locust-master-service"
- "--master-port=5557"
- "--locustfile=/locust/locustfile.py"
volumes:
- name: locust-script
configMap:
name: locust-script
---
apiVersion: v1
kind: Service
metadata:
name: locust-master-service
spec:
selector:
app: locust
role: master
ports:
- name: web-ui
port: 8089
targetPort: 8089
- name: master-comm
port: 5557
targetPort: 5557
type: ClusterIP
---
apiVersion: v1
kind: Service
metadata:
name: locust-web-ui
spec:
selector:
app: locust
role: master
ports:
- port: 8089
targetPort: 8089
type: ClusterIP # Keep internal, use port-forward to access

View file

@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Locust load testing script for Llama Stack with Prism mock OpenAI provider.
"""
import random
from locust import HttpUser, task, between
import os
base_path = os.getenv("LOCUST_BASE_PATH", "/v1/openai/v1")
MODEL_ID = os.getenv("INFERENCE_MODEL")
class LlamaStackUser(HttpUser):
wait_time = between(0.0, 0.0001)
def on_start(self):
"""Setup authentication and test data."""
# No auth required for benchmark server
self.headers = {
"Content-Type": "application/json"
}
# Test messages of varying lengths
self.test_messages = [
[{"role": "user", "content": "Hi"}],
[{"role": "user", "content": "What is the capital of France?"}],
[{"role": "user", "content": "Explain quantum physics in simple terms."}],
[{"role": "user", "content": "Write a short story about a robot learning to paint."}],
[
{"role": "user", "content": "What is machine learning?"},
{"role": "assistant", "content": "Machine learning is a subset of AI..."},
{"role": "user", "content": "Can you give me a practical example?"}
]
]
@task(weight=100)
def chat_completion_streaming(self):
"""Test streaming chat completion (20% of requests)."""
messages = random.choice(self.test_messages)
payload = {
"model": MODEL_ID,
"messages": messages,
"stream": True,
"max_tokens": 100
}
with self.client.post(
f"{base_path}/chat/completions",
headers=self.headers,
json=payload,
stream=True,
catch_response=True
) as response:
if response.status_code == 200:
chunks_received = 0
try:
for line in response.iter_lines():
if line:
line_str = line.decode('utf-8')
if line_str.startswith('data: '):
chunks_received += 1
if line_str.strip() == 'data: [DONE]':
break
if chunks_received > 0:
response.success()
else:
response.failure("No streaming chunks received")
except Exception as e:
response.failure(f"Streaming error: {e}")
else:
response.failure(f"HTTP {response.status_code}: {response.text}")

View file

@ -0,0 +1,52 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: openai-mock
labels:
app: openai-mock
spec:
replicas: 1
selector:
matchLabels:
app: openai-mock
template:
metadata:
labels:
app: openai-mock
spec:
containers:
- name: openai-mock
image: python:3.12-slim
ports:
- containerPort: ${MOCK_INFERENCE_PORT}
env:
- name: PORT
value: "${MOCK_INFERENCE_PORT}"
- name: MOCK_MODELS
value: "${MOCK_INFERENCE_MODEL}"
- name: STREAM_DELAY_SECONDS
value: "${STREAM_DELAY_SECONDS}"
command: ["sh", "-c"]
args:
- |
pip install flask &&
python /app/openai-mock-server.py --port ${MOCK_INFERENCE_PORT}
volumeMounts:
- name: openai-mock-script
mountPath: /app
volumes:
- name: openai-mock-script
configMap:
name: openai-mock
---
apiVersion: v1
kind: Service
metadata:
name: openai-mock-service
spec:
selector:
app: openai-mock
ports:
- port: 8080
targetPort: 8080
type: ClusterIP

View file

@ -0,0 +1,190 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
OpenAI-compatible mock server that returns:
- Hardcoded /models response for consistent validation
- Valid OpenAI-formatted chat completion responses with dynamic content
"""
from flask import Flask, request, jsonify, Response
import time
import random
import uuid
import json
import argparse
import os
app = Flask(__name__)
# Models from environment variables
def get_models():
models_str = os.getenv("MOCK_MODELS", "mock-inference")
model_ids = [m.strip() for m in models_str.split(",") if m.strip()]
return {
"object": "list",
"data": [
{
"id": model_id,
"object": "model",
"created": 1234567890,
"owned_by": "vllm"
}
for model_id in model_ids
]
}
def generate_random_text(length=50):
"""Generate random but coherent text for responses."""
words = [
"Hello", "there", "I'm", "an", "AI", "assistant", "ready", "to", "help", "you",
"with", "your", "questions", "and", "tasks", "today", "Let", "me","know", "what",
"you'd", "like", "to", "discuss", "or", "explore", "together", "I", "can", "assist",
"with", "various", "topics", "including", "coding", "writing", "analysis", "and", "more"
]
return " ".join(random.choices(words, k=length))
@app.route('/models', methods=['GET'])
def list_models():
models = get_models()
print(f"[MOCK] Returning models: {[m['id'] for m in models['data']]}")
return jsonify(models)
@app.route('/chat/completions', methods=['POST'])
def chat_completions():
"""Return OpenAI-formatted chat completion responses."""
data = request.get_json()
default_model = get_models()['data'][0]['id']
model = data.get('model', default_model)
messages = data.get('messages', [])
stream = data.get('stream', False)
print(f"[MOCK] Chat completion request - model: {model}, stream: {stream}")
if stream:
return handle_streaming_completion(model, messages)
else:
return handle_non_streaming_completion(model, messages)
def handle_non_streaming_completion(model, messages):
response_text = generate_random_text(random.randint(20, 80))
# Calculate realistic token counts
prompt_tokens = sum(len(str(msg.get('content', '')).split()) for msg in messages)
completion_tokens = len(response_text.split())
response = {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": response_text
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens
}
}
return jsonify(response)
def handle_streaming_completion(model, messages):
def generate_stream():
# Generate response text
full_response = generate_random_text(random.randint(30, 100))
words = full_response.split()
# Send initial chunk
initial_chunk = {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": ""}
}
]
}
yield f"data: {json.dumps(initial_chunk)}\n\n"
# Send word by word
for i, word in enumerate(words):
chunk = {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": f"{word} " if i < len(words) - 1 else word}
}
]
}
yield f"data: {json.dumps(chunk)}\n\n"
# Configurable delay to simulate realistic streaming
stream_delay = float(os.getenv("STREAM_DELAY_SECONDS", "0.005"))
time.sleep(stream_delay)
# Send final chunk
final_chunk = {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": ""},
"finish_reason": "stop"
}
]
}
yield f"data: {json.dumps(final_chunk)}\n\n"
yield "data: [DONE]\n\n"
return Response(
generate_stream(),
mimetype='text/event-stream',
headers={
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'Access-Control-Allow-Origin': '*',
}
)
@app.route('/health', methods=['GET'])
def health():
return jsonify({"status": "healthy", "type": "openai-mock"})
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='OpenAI-compatible mock server')
parser.add_argument('--port', type=int, default=8081,
help='Port to run the server on (default: 8081)')
args = parser.parse_args()
port = args.port
models = get_models()
print("Starting OpenAI-compatible mock server...")
print(f"- /models endpoint with: {[m['id'] for m in models['data']]}")
print("- OpenAI-formatted chat/completion responses with dynamic content")
print("- Streaming support with valid SSE format")
print(f"- Listening on: http://0.0.0.0:{port}")
app.run(host='0.0.0.0', port=port, debug=False)

View file

@ -0,0 +1,143 @@
apiVersion: v1
data:
stack_run_config.yaml: |
version: '2'
image_name: kubernetes-benchmark-demo
apis:
- agents
- inference
- safety
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: vllm-inference
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: vllm-safety
provider_type: remote::vllm
config:
url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: mock-vllm-inference
provider_type: remote::vllm
config:
url: http://openai-mock-service:${env.MOCK_INFERENCE_PORT}
max_tokens: 4096
api_token: fake
tls_verify: false
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
kvstore:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
responses_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:+}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:+}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
table_name: llamastack_kvstore
inference_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
models:
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
- model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference
model_type: llm
- model_id: ${env.SAFETY_MODEL}
provider_id: vllm-safety
model_type: llm
- model_id: ${env.MOCK_INFERENCE_MODEL}
provider_id: mock-vllm-inference
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8323
kind: ConfigMap
metadata:
creationTimestamp: null
name: llama-stack-config

View file

@ -0,0 +1,87 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: llama-benchmark-pvc
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 1Gi
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: llama-stack-benchmark-server
spec:
replicas: 1
selector:
matchLabels:
app.kubernetes.io/name: llama-stack-benchmark
app.kubernetes.io/component: server
template:
metadata:
labels:
app.kubernetes.io/name: llama-stack-benchmark
app.kubernetes.io/component: server
spec:
containers:
- name: llama-stack-benchmark
image: llamastack/distribution-starter:latest
imagePullPolicy: Always # since we have specified latest instead of a version
env:
- name: ENABLE_CHROMADB
value: "true"
- name: CHROMADB_URL
value: http://chromadb.default.svc.cluster.local:6000
- name: POSTGRES_HOST
value: postgres-server.default.svc.cluster.local
- name: POSTGRES_PORT
value: "5432"
- name: INFERENCE_MODEL
value: "${INFERENCE_MODEL}"
- name: SAFETY_MODEL
value: "${SAFETY_MODEL}"
- name: TAVILY_SEARCH_API_KEY
value: "${TAVILY_SEARCH_API_KEY}"
- name: MOCK_INFERENCE_PORT
value: "${MOCK_INFERENCE_PORT}"
- name: VLLM_URL
value: http://vllm-server.default.svc.cluster.local:8000/v1
- name: VLLM_MAX_TOKENS
value: "3072"
- name: VLLM_SAFETY_URL
value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
- name: VLLM_TLS_VERIFY
value: "false"
- name: MOCK_INFERENCE_MODEL
value: "${MOCK_INFERENCE_MODEL}"
command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
ports:
- containerPort: 8323
volumeMounts:
- name: llama-storage
mountPath: /root/.llama
- name: llama-config
mountPath: /etc/config
volumes:
- name: llama-storage
persistentVolumeClaim:
claimName: llama-benchmark-pvc
- name: llama-config
configMap:
name: llama-stack-config
---
apiVersion: v1
kind: Service
metadata:
name: llama-stack-benchmark-service
spec:
selector:
app.kubernetes.io/name: llama-stack-benchmark
app.kubernetes.io/component: server
ports:
- name: http
port: 8323
targetPort: 8323
type: ClusterIP

View file

@ -0,0 +1,136 @@
version: '2'
image_name: kubernetes-benchmark-demo
apis:
- agents
- inference
- safety
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: vllm-inference
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: vllm-safety
provider_type: remote::vllm
config:
url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: mock-vllm-inference
provider_type: remote::vllm
config:
url: http://openai-mock-service:${env.MOCK_INFERENCE_PORT}
max_tokens: 4096
api_token: fake
tls_verify: false
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
kvstore:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
responses_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:+}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:+}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
table_name: llamastack_kvstore
inference_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
models:
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
- model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference
model_type: llm
- model_id: ${env.SAFETY_MODEL}
provider_id: vllm-safety
model_type: llm
- model_id: ${env.MOCK_INFERENCE_MODEL}
provider_id: mock-vllm-inference
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8323

View file

@ -40,19 +40,19 @@ spec:
value: "3072"
- name: VLLM_SAFETY_URL
value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
- name: VLLM_TLS_VERIFY
value: "false"
- name: POSTGRES_HOST
value: postgres-server.default.svc.cluster.local
- name: POSTGRES_PORT
value: "5432"
- name: VLLM_TLS_VERIFY
value: "false"
- name: INFERENCE_MODEL
value: "${INFERENCE_MODEL}"
- name: SAFETY_MODEL
value: "${SAFETY_MODEL}"
- name: TAVILY_SEARCH_API_KEY
value: "${TAVILY_SEARCH_API_KEY}"
command: ["python", "-m", "llama_stack.core.server.server", "--config", "/etc/config/stack_run_config.yaml", "--port", "8321"]
command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8321"]
ports:
- containerPort: 8321
volumeMounts:

View file

@ -226,7 +226,7 @@ uv init
name = "llama-stack-provider-ollama"
version = "0.1.0"
description = "Ollama provider for Llama Stack"
requires-python = ">=3.10"
requires-python = ">=3.12"
dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
```

View file

@ -29,6 +29,7 @@ remote_runpod
remote_sambanova
remote_tgi
remote_together
remote_vertexai
remote_vllm
remote_watsonx
```

View file

@ -0,0 +1,40 @@
# remote::vertexai
## Description
Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
• Enterprise-grade security: Uses Google Cloud's security controls and IAM
• Better integration: Seamless integration with other Google Cloud services
• Advanced features: Access to additional Vertex AI features like model tuning and monitoring
• Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys
Configuration:
- Set VERTEX_AI_PROJECT environment variable (required)
- Set VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1)
- Use Google Cloud Application Default Credentials or service account key
Authentication Setup:
Option 1 (Recommended): gcloud auth application-default login
Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to service account key path
Available Models:
- vertex_ai/gemini-2.0-flash
- vertex_ai/gemini-2.5-flash
- vertex_ai/gemini-2.5-pro
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `project` | `<class 'str'>` | No | | Google Cloud project ID for Vertex AI |
| `location` | `<class 'str'>` | No | us-central1 | Google Cloud location for Vertex AI |
## Sample Configuration
```yaml
project: ${env.VERTEX_AI_PROJECT:=}
location: ${env.VERTEX_AI_LOCATION:=us-central1}
```

View file

@ -12,6 +12,18 @@ That means you'll get fast and efficient vector retrieval.
- Lightweight and easy to use
- Fully integrated with Llama Stack
- GPU support
- **Vector search** - FAISS supports pure vector similarity search using embeddings
## Search Modes
**Supported:**
- **Vector Search** (`mode="vector"`): Performs vector similarity search using embeddings
**Not Supported:**
- **Keyword Search** (`mode="keyword"`): Not supported by FAISS
- **Hybrid Search** (`mode="hybrid"`): Not supported by FAISS
> **Note**: FAISS is designed as a pure vector similarity search library. See the [FAISS GitHub repository](https://github.com/facebookresearch/faiss) for more details about FAISS's core functionality.
## Usage

View file

@ -21,5 +21,7 @@ kvstore:
## Deprecation Notice
⚠️ **Warning**: Please use the `inline::faiss` provider instead.
```{warning}
Please use the `inline::faiss` provider instead.
```

View file

@ -25,5 +25,7 @@ kvstore:
## Deprecation Notice
⚠️ **Warning**: Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.
```{warning}
Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.
```

View file

@ -11,6 +11,7 @@ That means you're not limited to storing vectors in memory or in a separate serv
- Easy to use
- Fully integrated with Llama Stack
- Supports all search modes: vector, keyword, and hybrid search (both inline and remote configurations)
## Usage
@ -101,6 +102,92 @@ vector_io:
- **`client_pem_path`**: Path to the **client certificate** file (required for mTLS).
- **`client_key_path`**: Path to the **client private key** file (required for mTLS).
## Search Modes
Milvus supports three different search modes for both inline and remote configurations:
### Vector Search
Vector search uses semantic similarity to find the most relevant chunks based on embedding vectors. This is the default search mode and works well for finding conceptually similar content.
```python
# Vector search example
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="What is machine learning?",
search_mode="vector",
max_num_results=5,
)
```
### Keyword Search
Keyword search uses traditional text-based matching to find chunks containing specific terms or phrases. This is useful when you need exact term matches.
```python
# Keyword search example
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="Python programming language",
search_mode="keyword",
max_num_results=5,
)
```
### Hybrid Search
Hybrid search combines both vector and keyword search methods to provide more comprehensive results. It leverages the strengths of both semantic similarity and exact term matching.
#### Basic Hybrid Search
```python
# Basic hybrid search example (uses RRF ranker with default impact_factor=60.0)
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="neural networks in Python",
search_mode="hybrid",
max_num_results=5,
)
```
**Note**: The default `impact_factor` value of 60.0 was empirically determined to be optimal in the original RRF research paper: ["Reciprocal Rank Fusion outperforms Condorcet and individual Rank Learning Methods"](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) (Cormack et al., 2009).
#### Hybrid Search with RRF (Reciprocal Rank Fusion) Ranker
RRF combines rankings from vector and keyword search by using reciprocal ranks. The impact factor controls how much weight is given to higher-ranked results.
```python
# Hybrid search with custom RRF parameters
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="neural networks in Python",
search_mode="hybrid",
max_num_results=5,
ranking_options={
"ranker": {
"type": "rrf",
"impact_factor": 100.0, # Higher values give more weight to top-ranked results
}
},
)
```
#### Hybrid Search with Weighted Ranker
Weighted ranker linearly combines normalized scores from vector and keyword search. The alpha parameter controls the balance between the two search methods.
```python
# Hybrid search with weighted ranker
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="neural networks in Python",
search_mode="hybrid",
max_num_results=5,
ranking_options={
"ranker": {
"type": "weighted",
"alpha": 0.7, # 70% vector search, 30% keyword search
}
},
)
```
For detailed documentation on RRF and Weighted rankers, please refer to the [Milvus Reranking Guide](https://milvus.io/docs/reranking.md).
## Documentation
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
@ -117,7 +204,10 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend |
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
```{note}
This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
```
## Sample Configuration

View file

@ -128,7 +128,9 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
**Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
```{tip}
Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
```
## List the downloaded models

View file

@ -152,7 +152,9 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
**Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
```{tip}
Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
```
## List the downloaded models

View file

@ -706,6 +706,7 @@ class Agents(Protocol):
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a new OpenAI response.
@ -713,6 +714,7 @@ class Agents(Protocol):
:param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
:param include: (Optional) Additional fields to include in the response.
:returns: An OpenAIResponseObject.
"""
...

View file

@ -170,6 +170,23 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
type: Literal["web_search_call"] = "web_search_call"
class OpenAIResponseOutputMessageFileSearchToolCallResults(BaseModel):
"""Search results returned by the file search operation.
:param attributes: (Optional) Key-value attributes associated with the file
:param file_id: Unique identifier of the file containing the result
:param filename: Name of the file containing the result
:param score: Relevance score for this search result (between 0 and 1)
:param text: Text content of the search result
"""
attributes: dict[str, Any]
file_id: str
filename: str
score: float
text: str
@json_schema_type
class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
"""File search tool call output message for OpenAI responses.
@ -185,7 +202,7 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
queries: list[str]
status: str
type: Literal["file_search_call"] = "file_search_call"
results: list[dict[str, Any]] | None = None
results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
@json_schema_type
@ -606,6 +623,62 @@ class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel):
type: Literal["response.mcp_call.completed"] = "response.mcp_call.completed"
@json_schema_type
class OpenAIResponseContentPartOutputText(BaseModel):
type: Literal["output_text"] = "output_text"
text: str
# TODO: add annotations, logprobs, etc.
@json_schema_type
class OpenAIResponseContentPartRefusal(BaseModel):
type: Literal["refusal"] = "refusal"
refusal: str
OpenAIResponseContentPart = Annotated[
OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal,
Field(discriminator="type"),
]
register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart")
@json_schema_type
class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel):
"""Streaming event for when a new content part is added to a response item.
:param response_id: Unique identifier of the response containing this content
:param item_id: Unique identifier of the output item containing this content part
:param part: The content part that was added
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.content_part.added"
"""
response_id: str
item_id: str
part: OpenAIResponseContentPart
sequence_number: int
type: Literal["response.content_part.added"] = "response.content_part.added"
@json_schema_type
class OpenAIResponseObjectStreamResponseContentPartDone(BaseModel):
"""Streaming event for when a content part is completed.
:param response_id: Unique identifier of the response containing this content
:param item_id: Unique identifier of the output item containing this content part
:param part: The completed content part
:param sequence_number: Sequential number for ordering streaming events
:param type: Event type identifier, always "response.content_part.done"
"""
response_id: str
item_id: str
part: OpenAIResponseContentPart
sequence_number: int
type: Literal["response.content_part.done"] = "response.content_part.done"
OpenAIResponseObjectStream = Annotated[
OpenAIResponseObjectStreamResponseCreated
| OpenAIResponseObjectStreamResponseOutputItemAdded
@ -625,6 +698,8 @@ OpenAIResponseObjectStream = Annotated[
| OpenAIResponseObjectStreamResponseMcpCallInProgress
| OpenAIResponseObjectStreamResponseMcpCallFailed
| OpenAIResponseObjectStreamResponseMcpCallCompleted
| OpenAIResponseObjectStreamResponseContentPartAdded
| OpenAIResponseObjectStreamResponseContentPartDone
| OpenAIResponseObjectStreamResponseCompleted,
Field(discriminator="type"),
]

View file

@ -62,3 +62,13 @@ class SessionNotFoundError(ValueError):
def __init__(self, session_name: str) -> None:
message = f"Session '{session_name}' not found or access denied."
super().__init__(message)
class ModelTypeError(TypeError):
"""raised when a model is present but not the correct type"""
def __init__(self, model_name: str, model_type: str, expected_model_type: str) -> None:
message = (
f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'"
)
super().__init__(message)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum, StrEnum
from enum import Enum
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field
@ -15,27 +15,6 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
# OpenAI Categories to return in the response
class OpenAICategories(StrEnum):
"""
Required set of categories in moderations api response
"""
VIOLENCE = "violence"
VIOLENCE_GRAPHIC = "violence/graphic"
HARRASMENT = "harassment"
HARRASMENT_THREATENING = "harassment/threatening"
HATE = "hate"
HATE_THREATENING = "hate/threatening"
ILLICIT = "illicit"
ILLICIT_VIOLENT = "illicit/violent"
SEXUAL = "sexual"
SEXUAL_MINORS = "sexual/minors"
SELF_HARM = "self-harm"
SELF_HARM_INTENT = "self-harm/intent"
SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
@json_schema_type
class ModerationObjectResults(BaseModel):
"""A moderation object.
@ -43,20 +22,6 @@ class ModerationObjectResults(BaseModel):
:param categories: A list of the categories, and whether they are flagged or not.
:param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
:param category_scores: A list of the categories along with their scores as predicted by model.
Required set of categories that need to be in response
- violence
- violence/graphic
- harassment
- harassment/threatening
- hate
- hate/threatening
- illicit
- illicit/violent
- sexual
- sexual/minors
- self-harm
- self-harm/intent
- self-harm/instructions
"""
flagged: bool

View file

@ -91,7 +91,7 @@ def get_provider_dependencies(
def print_pip_install_help(config: BuildConfig):
normal_deps, special_deps = get_provider_dependencies(config)
normal_deps, special_deps, _ = get_provider_dependencies(config)
cprint(
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",

View file

@ -380,8 +380,17 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
json_content = json.dumps(convert_pydantic_to_json_value(result))
filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}
status_code = httpx.codes.OK
if options.method.upper() == "DELETE" and result is None:
status_code = httpx.codes.NO_CONTENT
if status_code == httpx.codes.NO_CONTENT:
json_content = ""
mock_response = httpx.Response(
status_code=httpx.codes.OK,
status_code=status_code,
content=json_content.encode("utf-8"),
headers={
"Content-Type": "application/json",

View file

@ -18,7 +18,7 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="inference")
class InferenceRouter(Inference):
@ -177,6 +177,15 @@ class InferenceRouter(Inference):
encoded = self.formatter.encode_content(messages)
return len(encoded.tokens) if encoded and encoded.tokens else 0
async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
"""takes a model id and gets model after ensuring that it is accessible and of the correct type"""
model = await self.routing_table.get_model(model_id)
if model is None:
raise ModelNotFoundError(model_id)
if model.model_type != expected_model_type:
raise ModelTypeError(model_id, model.model_type, expected_model_type)
return model
async def chat_completion(
self,
model_id: str,
@ -195,11 +204,7 @@ class InferenceRouter(Inference):
)
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.routing_table.get_model(model_id)
if model is None:
raise ModelNotFoundError(model_id)
if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
model = await self._get_model(model_id, ModelType.llm)
if tool_config:
if tool_choice and tool_choice != tool_config.tool_choice:
raise ValueError("tool_choice and tool_config.tool_choice must match")
@ -301,11 +306,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
)
model = await self.routing_table.get_model(model_id)
if model is None:
raise ModelNotFoundError(model_id)
if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
model = await self._get_model(model_id, ModelType.llm)
provider = await self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,
@ -355,11 +356,7 @@ class InferenceRouter(Inference):
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
logger.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
if model is None:
raise ModelNotFoundError(model_id)
if model.model_type == ModelType.llm:
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
await self._get_model(model_id, ModelType.embedding)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.embeddings(
model_id=model_id,
@ -395,12 +392,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ModelNotFoundError(model)
if model_obj.model_type == ModelType.embedding:
raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
model_obj = await self._get_model(model, ModelType.llm)
params = dict(
model=model_obj.identifier,
prompt=prompt,
@ -476,11 +468,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ModelNotFoundError(model)
if model_obj.model_type == ModelType.embedding:
raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
model_obj = await self._get_model(model, ModelType.llm)
# Use the OpenAI client for a bit of extra input validation without
# exposing the OpenAI client itself as part of our API surface
@ -567,12 +555,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ModelNotFoundError(model)
if model_obj.model_type != ModelType.embedding:
raise ValueError(f"Model '{model}' is not an embedding model")
model_obj = await self._get_model(model, ModelType.embedding)
params = dict(
model=model_obj.identifier,
input=input,
@ -871,4 +854,5 @@ class InferenceRouter(Inference):
model=model.identifier,
object="chat.completion",
)
logger.debug(f"InferenceRouter.completion_response: {final_response}")
await self.store.store_chat_completion(final_response, messages)

View file

@ -10,7 +10,7 @@ from llama_stack.apis.inference import (
Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
@ -82,20 +82,5 @@ class SafetyRouter(Safety):
input=input,
model=model,
)
self._validate_required_categories_exist(response)
return response
def _validate_required_categories_exist(self, response: ModerationObject) -> None:
"""Validate the ProviderImpl response contains the required Open AI moderations categories."""
required_categories = list(map(str, OpenAICategories))
categories = response.results[0].categories
category_applied_input_types = response.results[0].category_applied_input_types
category_scores = response.results[0].category_scores
for i in [categories, category_applied_input_types, category_scores]:
if not set(required_categories).issubset(set(i.keys())):
raise ValueError(
f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}"
)

View file

@ -63,6 +63,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def get_provider_impl(self, model_id: str) -> Any:
model = await lookup_model(self, model_id)
if model.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider {model.provider_id} not found in the routing table")
return self.impls_by_provider_id[model.provider_id]
async def register_model(

View file

@ -124,10 +124,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return toolgroup
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ToolGroupNotFoundError(toolgroup_id)
await self.unregister_object(tool_group)
await self.unregister_object(await self.get_tool_group(toolgroup_id))
async def shutdown(self) -> None:
pass

View file

@ -8,7 +8,7 @@ from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.common.errors import ModelNotFoundError, VectorStoreNotFoundError
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
@ -66,7 +66,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if model is None:
raise ModelNotFoundError(embedding_model)
if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model")
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
vector_db_data = {

View file

@ -21,10 +21,11 @@ from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Annotated, Any, get_origin
import httpx
import rich.pretty
import yaml
from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
@ -115,7 +116,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
if isinstance(exc, RequestValidationError):
return HTTPException(
status_code=400,
status_code=httpx.codes.BAD_REQUEST,
detail={
"errors": [
{
@ -128,20 +129,20 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
},
)
elif isinstance(exc, ValueError):
return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
elif isinstance(exc, BadRequestError):
return HTTPException(status_code=400, detail=str(exc))
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
elif isinstance(exc, PermissionError | AccessDeniedError):
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
elif isinstance(exc, NotImplementedError):
return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
return HTTPException(status_code=httpx.codes.NOT_IMPLEMENTED, detail=f"Not implemented: {str(exc)}")
elif isinstance(exc, AuthenticationRequiredError):
return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
return HTTPException(status_code=httpx.codes.UNAUTHORIZED, detail=f"Authentication required: {str(exc)}")
else:
return HTTPException(
status_code=500,
status_code=httpx.codes.INTERNAL_SERVER_ERROR,
detail="Internal server error: An unexpected error occurred.",
)
@ -180,7 +181,6 @@ async def sse_generator(event_gen_coroutine):
event_gen = await event_gen_coroutine
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
logger.info("Generator cancelled")
if event_gen:
@ -236,6 +236,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
result = await maybe_await(value)
if isinstance(result, PaginatedResponse) and result.url is None:
result.url = route
if method.upper() == "DELETE" and result is None:
return Response(status_code=httpx.codes.NO_CONTENT)
return result
except Exception as e:
if logger.isEnabledFor(logging.DEBUG):
@ -352,7 +356,7 @@ class ClientVersionMiddleware:
await send(
{
"type": "http.response.start",
"status": 426,
"status": httpx.codes.UPGRADE_REQUIRED,
"headers": [[b"content-type", b"application/json"]],
}
)

View file

@ -14,6 +14,7 @@ distribution_spec:
- provider_type: remote::openai
- provider_type: remote::anthropic
- provider_type: remote::gemini
- provider_type: remote::vertexai
- provider_type: remote::groq
- provider_type: remote::sambanova
- provider_type: inline::sentence-transformers

View file

@ -65,6 +65,11 @@ providers:
provider_type: remote::gemini
config:
api_key: ${env.GEMINI_API_KEY:=}
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
provider_type: remote::vertexai
config:
project: ${env.VERTEX_AI_PROJECT:=}
location: ${env.VERTEX_AI_LOCATION:=us-central1}
- provider_id: groq
provider_type: remote::groq
config:

View file

@ -16,6 +16,7 @@ from llama_stack.distributions.template import DistributionTemplate, RunConfigSe
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
def get_distribution_template() -> DistributionTemplate:
@ -71,9 +72,10 @@ def get_distribution_template() -> DistributionTemplate:
chromadb_provider = Provider(
provider_id="chromadb",
provider_type="remote::chromadb",
config={
"url": "${env.CHROMA_URL}",
},
config=ChromaVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}/",
url="${env.CHROMADB_URL:=}",
),
)
inference_model = ModelInput(

View file

@ -26,7 +26,10 @@ providers:
- provider_id: chromadb
provider_type: remote::chromadb
config:
url: ${env.CHROMA_URL}
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard

View file

@ -22,7 +22,10 @@ providers:
- provider_id: chromadb
provider_type: remote::chromadb
config:
url: ${env.CHROMA_URL}
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard

View file

@ -14,6 +14,7 @@ distribution_spec:
- provider_type: remote::openai
- provider_type: remote::anthropic
- provider_type: remote::gemini
- provider_type: remote::vertexai
- provider_type: remote::groq
- provider_type: remote::sambanova
- provider_type: inline::sentence-transformers

View file

@ -65,6 +65,11 @@ providers:
provider_type: remote::gemini
config:
api_key: ${env.GEMINI_API_KEY:=}
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
provider_type: remote::vertexai
config:
project: ${env.VERTEX_AI_PROJECT:=}
location: ${env.VERTEX_AI_LOCATION:=us-central1}
- provider_id: groq
provider_type: remote::groq
config:

View file

@ -56,6 +56,7 @@ ENABLED_INFERENCE_PROVIDERS = [
"fireworks",
"together",
"gemini",
"vertexai",
"groq",
"sambanova",
"anthropic",
@ -71,6 +72,7 @@ INFERENCE_PROVIDER_IDS = {
"tgi": "${env.TGI_URL:+tgi}",
"cerebras": "${env.CEREBRAS_API_KEY:+cerebras}",
"nvidia": "${env.NVIDIA_API_KEY:+nvidia}",
"vertexai": "${env.VERTEX_AI_PROJECT:+vertexai}",
}
@ -246,6 +248,14 @@ def get_distribution_template() -> DistributionTemplate:
"",
"Gemini API Key",
),
"VERTEX_AI_PROJECT": (
"",
"Google Cloud Project ID for Vertex AI",
),
"VERTEX_AI_LOCATION": (
"us-central1",
"Google Cloud Location for Vertex AI",
),
"SAMBANOVA_API_KEY": (
"",
"SambaNova API Key",

View file

@ -32,6 +32,7 @@ CATEGORIES = [
"tools",
"client",
"telemetry",
"openai_responses",
]
# Initialize category levels with default level
@ -99,7 +100,8 @@ def parse_environment_config(env_config: str) -> dict[str, int]:
Dict[str, int]: A dictionary mapping categories to their log levels.
"""
category_levels = {}
for pair in env_config.split(";"):
delimiter = ","
for pair in env_config.split(delimiter):
if not pair.strip():
continue

View file

@ -236,6 +236,7 @@ class ChatFormat:
arguments_json=json.dumps(tool_arguments),
)
)
content = ""
return RawMessage(
role="assistant",

View file

@ -68,6 +68,11 @@ from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolCall,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
convert_openai_chat_completion_stream,
convert_tooldef_to_openai_tool,
)
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.telemetry import tracing
@ -510,16 +515,60 @@ class ChatAgent(ShieldRunnerMixin):
async with tracing.span("inference") as span:
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self.tool_defs,
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
response_format=self.agent_config.response_format,
# Convert messages to OpenAI format
openai_messages = []
for message in input_messages:
openai_message = await convert_message_to_openai_dict(message)
openai_messages.append(openai_message)
# Convert tool definitions to OpenAI format
openai_tools = None
if self.tool_defs:
openai_tools = []
for tool_def in self.tool_defs:
openai_tool = convert_tooldef_to_openai_tool(tool_def)
openai_tools.append(openai_tool)
# Extract tool_choice from tool_config for OpenAI compatibility
# Note: tool_choice can only be provided when tools are also provided
tool_choice = None
if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice:
tool_choice = (
self.agent_config.tool_config.tool_choice.value
if hasattr(self.agent_config.tool_config.tool_choice, "value")
else str(self.agent_config.tool_config.tool_choice)
)
# Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
temperature = None
top_p = None
max_tokens = None
if sampling_params:
if hasattr(sampling_params.strategy, "temperature"):
temperature = sampling_params.strategy.temperature
if hasattr(sampling_params.strategy, "top_p"):
top_p = sampling_params.strategy.top_p
if sampling_params.max_tokens:
max_tokens = sampling_params.max_tokens
# Use OpenAI chat completion
openai_stream = await self.inference_api.openai_chat_completion(
model=self.agent_config.model,
messages=openai_messages,
tools=openai_tools if openai_tools else None,
tool_choice=tool_choice,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
stream=True,
sampling_params=sampling_params,
tool_config=self.agent_config.tool_config,
):
)
# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream(
openai_stream, enable_incremental_tool_calls=True
)
async for chunk in response_stream:
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
continue

View file

@ -48,8 +48,8 @@ from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
from .openai_responses import OpenAIResponsesImpl
from .persistence import AgentInfo
from .responses.openai_responses import OpenAIResponsesImpl
logger = logging.getLogger()
@ -327,10 +327,21 @@ class MetaReferenceAgentsImpl(Agents):
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
input, model, instructions, previous_response_id, store, stream, temperature, text, tools, max_infer_iters
input,
model,
instructions,
previous_response_id,
store,
stream,
temperature,
text,
tools,
include,
max_infer_iters,
)
async def list_openai_responses(

View file

@ -1,880 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import time
import uuid
from collections.abc import AsyncIterator
from typing import Any
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
AllowedToolsFilter,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolMCP,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseOutput,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
WebSearchToolTypes,
)
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAIJSONSchema,
OpenAIMessageParam,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatParam,
OpenAIResponseFormatText,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
logger = get_logger(name=__name__, category="openai_responses")
OPENAI_RESPONSES_PREFIX = "openai_responses:"
async def _convert_response_content_to_chat_content(
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent],
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
The content schemas of each API look similar, but are not exactly the same.
"""
if isinstance(content, str):
return content
converted_parts = []
for content_part in content:
if isinstance(content_part, OpenAIResponseInputMessageContentText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
if content_part.image_url:
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
elif isinstance(content_part, str):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
)
return converted_parts
async def _convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.call_id,
)
)
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id=input_item.call_id,
function=OpenAIChatCompletionToolCallFunction(
name=input_item.name,
arguments=input_item.arguments,
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
else:
content = await _convert_response_content_to_chat_content(input_item.content)
message_type = await _get_message_type_by_role(input_item.role)
if message_type is None:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
messages.append(message_type(content=content))
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages
async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
"""
Convert an OpenAI Chat Completion choice into an OpenAI Response output message.
"""
output_content = ""
if isinstance(choice.message.content, str):
output_content = choice.message.content
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
output_content = choice.message.content.text
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)
return OpenAIResponseMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
status="completed",
role="assistant",
)
async def _convert_response_text_to_chat_response_format(text: OpenAIResponseText) -> OpenAIResponseFormatParam:
"""
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
"""
if not text.format or text.format["type"] == "text":
return OpenAIResponseFormatText(type="text")
if text.format["type"] == "json_object":
return OpenAIResponseFormatJSONObject()
if text.format["type"] == "json_schema":
return OpenAIResponseFormatJSONSchema(
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
)
raise ValueError(f"Unsupported text format: {text.format}")
async def _get_message_type_by_role(role: str):
role_to_type = {
"user": OpenAIUserMessageParam,
"system": OpenAISystemMessageParam,
"assistant": OpenAIAssistantMessageParam,
"developer": OpenAIDeveloperMessageParam,
}
return role_to_type.get(role)
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
input_items: ListOpenAIResponseInputItem
response: OpenAIResponseObject
class ChatCompletionContext(BaseModel):
model: str
messages: list[OpenAIMessageParam]
response_tools: list[OpenAIResponseInputTool] | None = None
chat_tools: list[ChatCompletionToolParam] | None = None
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
temperature: float | None
response_format: OpenAIResponseFormatParam
class OpenAIResponsesImpl:
def __init__(
self,
inference_api: Inference,
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
async def _prepend_previous_response(
self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
):
if previous_response_id:
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
# previous response input items
new_input_items = previous_response_with_input.input
# previous response output items
new_input_items.extend(previous_response_with_input.output)
# new input items from the current request
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
else:
new_input_items.extend(input)
input = new_input_items
return input
async def _prepend_instructions(self, messages, instructions):
if instructions:
messages.insert(0, OpenAISystemMessageParam(content=instructions))
async def get_openai_response(
self,
response_id: str,
) -> OpenAIResponseObject:
response_with_input = await self.responses_store.get_response_object(response_id)
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
async def list_openai_responses(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
return await self.responses_store.list_responses(after, limit, model, order)
async def list_openai_response_input_items(
self,
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
"""List input items for a given OpenAI response.
:param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination.
:param before: An item ID to list items before, used for pagination.
:param include: Additional fields to include in the response.
:param limit: A limit on the number of objects to be returned.
:param order: The order to return the input items in.
:returns: An ListOpenAIResponseInputItem.
"""
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
async def _store_response(
self,
response: OpenAIResponseObject,
input: str | list[OpenAIResponseInput],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
await self.responses_store.store_response_object(
response_object=response,
input=input_items_data,
)
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
stream_gen = self._create_streaming_response(
input=input,
model=model,
instructions=instructions,
previous_response_id=previous_response_id,
store=store,
temperature=temperature,
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
)
if stream:
return stream_gen
else:
response = None
async for stream_chunk in stream_gen:
if stream_chunk.type == "response.completed":
if response is not None:
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
response = stream_chunk.response
# don't leave the generator half complete!
if response is None:
raise ValueError("The response stream never completed")
return response
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
) -> AsyncIterator[OpenAIResponseObjectStream]:
output_messages: list[OpenAIResponseOutput] = []
# Input preprocessing
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
# Structured outputs
response_format = await _convert_response_text_to_chat_response_format(text)
# Tool setup, TODO: refactor this slightly since this can also yield events
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
if mcp_list_message:
output_messages.append(mcp_list_message)
ctx = ChatCompletionContext(
model=model,
messages=messages,
response_tools=tools,
chat_tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
temperature=temperature,
response_format=response_format,
)
# Create initial response and emit response.created immediately
response_id = f"resp-{uuid.uuid4()}"
created_at = int(time.time())
initial_response = OpenAIResponseObject(
created_at=created_at,
id=response_id,
model=model,
object="response",
status="in_progress",
output=output_messages.copy(),
text=text,
)
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
n_iter = 0
messages = ctx.messages.copy()
while True:
completion_result = await self.inference_api.openai_chat_completion(
model=ctx.model,
messages=messages,
tools=ctx.chat_tools,
stream=True,
temperature=ctx.temperature,
response_format=ctx.response_format,
)
# Process streaming chunks and build complete response
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
sequence_number = 0
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
async for chunk in completion_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
sequence_number=sequence_number,
)
# Collect content for final response
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
# Don't attempt to concatenate arguments if we don't have any new argumentsAdd commentMore actions
if tool_call.function.arguments:
# Guard against an initial None argument before we concatenate
response_tool_call.function.arguments = (
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert collected chunks to complete response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
current_response = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
function_tool_calls = []
non_function_tool_calls = []
next_turn_messages = messages.copy()
for choice in current_response.choices:
next_turn_messages.append(choice.message)
if choice.message.tool_calls and tools:
for tool_call in choice.message.tool_calls:
if _is_function_tool_call(tool_call, tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
# execute non-function tool calls
for tool_call in non_function_tool_calls:
tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
if tool_call_log:
output_messages.append(tool_call_log)
if tool_response_message:
next_turn_messages.append(tool_response_message)
for tool_call in function_tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=f"fc_{uuid.uuid4()}",
status="completed",
)
)
if not function_tool_calls and not non_function_tool_calls:
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
n_iter += 1
if n_iter >= max_infer_iters:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
break
messages = next_turn_messages
# Create final response
final_response = OpenAIResponseObject(
created_at=created_at,
id=response_id,
model=model,
object="response",
status="completed",
text=text,
output=output_messages,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
if store:
await self._store_response(
response=final_response,
input=input,
)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
return await self.responses_store.delete_response_object(response_id)
async def _convert_response_tools_to_chat_tools(
self, tools: list[OpenAIResponseInputTool]
) -> tuple[
list[ChatCompletionToolParam],
dict[str, OpenAIResponseInputToolMCP],
OpenAIResponseOutput | None,
]:
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
)
from llama_stack.apis.tools import Tool
mcp_tool_to_server = {}
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
)
return convert_tooldef_to_openai_tool(tool_def)
mcp_list_message = None
chat_tools: list[ChatCompletionToolParam] = []
for input_tool in tools:
# TODO: Handle other tool types
if input_tool.type == "function":
chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
elif input_tool.type in WebSearchToolTypes:
tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "file_search":
tool_name = "knowledge_search"
tool = await self.tool_groups_api.get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "mcp":
from llama_stack.providers.utils.tools.mcp import list_mcp_tools
always_allowed = None
never_allowed = None
if input_tool.allowed_tools:
if isinstance(input_tool.allowed_tools, list):
always_allowed = input_tool.allowed_tools
elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
always_allowed = input_tool.allowed_tools.always
never_allowed = input_tool.allowed_tools.never
tool_defs = await list_mcp_tools(
endpoint=input_tool.server_url,
headers=input_tool.headers or {},
)
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
status="completed",
server_label=input_tool.server_label,
tools=[],
)
for t in tool_defs.data:
if never_allowed and t.name in never_allowed:
continue
if not always_allowed or t.name in always_allowed:
chat_tools.append(make_openai_tool(t.name, t))
if t.name in mcp_tool_to_server:
raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
mcp_tool_to_server[t.name] = input_tool
mcp_list_message.tools.append(
MCPListToolsTool(
name=t.name,
description=t.description,
input_schema={
"type": "object",
"properties": {
p.name: {
"type": p.parameter_type,
"description": p.description,
}
for p in t.parameters
},
"required": [p.name for p in t.parameters if p.required],
},
)
)
else:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools, mcp_tool_to_server, mcp_list_message
async def _execute_knowledge_search_via_vector_store(
self,
query: str,
response_file_search_tool: OpenAIResponseInputToolFileSearch,
) -> ToolInvocationResult:
"""Execute knowledge search using vector_stores.search API with filters support."""
search_results = []
# Create search tasks for all vector stores
async def search_single_store(vector_store_id):
try:
search_response = await self.vector_io_api.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=response_file_search_tool.filters,
max_num_results=response_file_search_tool.max_num_results,
ranking_options=response_file_search_tool.ranking_options,
rewrite_query=False,
)
return search_response.data
except Exception as e:
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
return []
# Run all searches in parallel using gather
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
all_results = await asyncio.gather(*search_tasks)
# Flatten results
for results in all_results:
search_results.extend(results)
# Convert search results to tool result format matching memory.py
# Format the results as interleaved content similar to memory.py
content_items = []
content_items.append(
TextContentItem(
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
)
for i, result_item in enumerate(search_results):
chunk_text = result_item.content[0].text if result_item.content else ""
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
if result_item.attributes:
metadata_text += f", attributes: {result_item.attributes}"
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
content_items.append(TextContentItem(text=text_content))
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
content_items.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
)
)
return ToolInvocationResult(
content=content_items,
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
"scores": [r.score for r in search_results],
},
)
async def _execute_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
ctx: ChatCompletionContext,
) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
tool_call_id = tool_call.id
function = tool_call.function
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
if not function or not tool_call_id or not function.name:
return None, None
error_exc = None
result = None
try:
if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
mcp_tool = ctx.mcp_tool_to_server[function.name]
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function.name,
kwargs=tool_kwargs,
)
elif function.name == "knowledge_search":
response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None
)
if response_file_search_tool:
# Use vector_stores.search API instead of knowledge_search tool
# to support filters and ranking_options
query = tool_kwargs.get("query", "")
result = await self._execute_knowledge_search_via_vector_store(
query=query,
response_file_search_tool=response_file_search_tool,
)
else:
result = await self.tool_runtime_api.invoke_tool(
tool_name=function.name,
kwargs=tool_kwargs,
)
except Exception as e:
error_exc = e
if function.name in ctx.mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
message = OpenAIResponseOutputMessageMCPCall(
id=tool_call_id,
arguments=function.arguments,
name=function.name,
server_label=ctx.mcp_tool_to_server[function.name].server_label,
)
if error_exc:
message.error = str(error_exc)
elif (result.error_code and result.error_code > 0) or result.error_message:
message.error = f"Error (code {result.error_code}): {result.error_message}"
elif result.content:
message.output = interleaved_content_as_str(result.content)
else:
if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
message.status = "failed"
elif function.name == "knowledge_search":
message = OpenAIResponseOutputMessageFileSearchToolCall(
id=tool_call_id,
queries=[tool_kwargs.get("query", "")],
status="completed",
)
if "document_ids" in result.metadata:
message.results = []
for i, doc_id in enumerate(result.metadata["document_ids"]):
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
score = result.metadata["scores"][i] if "scores" in result.metadata else None
message.results.append(
{
"file_id": doc_id,
"filename": doc_id,
"text": text,
"score": score,
}
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
message.status = "failed"
else:
raise ValueError(f"Unknown tool {function.name} called")
input_message = None
if result and result.content:
if isinstance(result.content, str):
content = result.content
elif isinstance(result.content, list):
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
content = []
for item in result.content:
if isinstance(item, TextContentItem):
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
elif isinstance(item, ImageContentItem):
if item.image.data:
url = f"data:image;base64,{item.image.data}"
else:
url = item.image.url
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
else:
raise ValueError(f"Unknown result content type: {type(item)}")
content.append(part)
else:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
else:
text = str(error_exc)
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
return message, input_message
def _is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],
) -> bool:
if not tool_call.function:
return False
for t in tools:
if t.type == "function" and t.name == tool_call.function.name:
return True
return False

View file

@ -191,7 +191,11 @@ class AgentPersistence:
sessions = []
for value in values:
try:
session_info = Session(**json.loads(value))
data = json.loads(value)
if "turn_id" in data:
continue
session_info = Session(**data)
sessions.append(session_info)
except Exception as e:
log.error(f"Error parsing session info: {e}")

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,271 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
import uuid
from collections.abc import AsyncIterator
from pydantic import BaseModel
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.inference import (
Inference,
OpenAISystemMessageParam,
)
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
from .types import ChatCompletionContext
from .utils import (
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
)
logger = get_logger(name=__name__, category="responses")
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
input_items: ListOpenAIResponseInputItem
response: OpenAIResponseObject
class OpenAIResponsesImpl:
def __init__(
self,
inference_api: Inference,
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
self.tool_executor = ToolExecutor(
tool_groups_api=tool_groups_api,
tool_runtime_api=tool_runtime_api,
vector_io_api=vector_io_api,
)
async def _prepend_previous_response(
self,
input: str | list[OpenAIResponseInput],
previous_response_id: str | None = None,
):
if previous_response_id:
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
# previous response input items
new_input_items = previous_response_with_input.input
# previous response output items
new_input_items.extend(previous_response_with_input.output)
# new input items from the current request
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
else:
new_input_items.extend(input)
input = new_input_items
return input
async def _prepend_instructions(self, messages, instructions):
if instructions:
messages.insert(0, OpenAISystemMessageParam(content=instructions))
async def get_openai_response(
self,
response_id: str,
) -> OpenAIResponseObject:
response_with_input = await self.responses_store.get_response_object(response_id)
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
async def list_openai_responses(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
return await self.responses_store.list_responses(after, limit, model, order)
async def list_openai_response_input_items(
self,
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
"""List input items for a given OpenAI response.
:param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination.
:param before: An item ID to list items before, used for pagination.
:param include: Additional fields to include in the response.
:param limit: A limit on the number of objects to be returned.
:param order: The order to return the input items in.
:returns: An ListOpenAIResponseInputItem.
"""
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
async def _store_response(
self,
response: OpenAIResponseObject,
input: str | list[OpenAIResponseInput],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
await self.responses_store.store_response_object(
response_object=response,
input=input_items_data,
)
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
stream_gen = self._create_streaming_response(
input=input,
model=model,
instructions=instructions,
previous_response_id=previous_response_id,
store=store,
temperature=temperature,
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
)
if stream:
return stream_gen
else:
response = None
async for stream_chunk in stream_gen:
if stream_chunk.type == "response.completed":
if response is not None:
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
response = stream_chunk.response
# don't leave the generator half complete!
if response is None:
raise ValueError("The response stream never completed")
return response
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing
input = await self._prepend_previous_response(input, previous_response_id)
messages = await convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
# Structured outputs
response_format = await convert_response_text_to_chat_response_format(text)
ctx = ChatCompletionContext(
model=model,
messages=messages,
response_tools=tools,
temperature=temperature,
response_format=response_format,
)
# Create orchestrator and delegate streaming logic
response_id = f"resp-{uuid.uuid4()}"
created_at = int(time.time())
orchestrator = StreamingResponseOrchestrator(
inference_api=self.inference_api,
ctx=ctx,
response_id=response_id,
created_at=created_at,
text=text,
max_infer_iters=max_infer_iters,
tool_executor=self.tool_executor,
)
# Stream the response
final_response = None
async for stream_chunk in orchestrator.create_response():
if stream_chunk.type == "response.completed":
final_response = stream_chunk.response
yield stream_chunk
# Store the response if requested
if store and final_response:
await self._store_response(
response=final_response,
input=input,
)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
return await self.responses_store.delete_response_object(response_id)

View file

@ -0,0 +1,634 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.agents.openai_responses import (
AllowedToolsFilter,
MCPListToolsTool,
OpenAIResponseContentPartOutputText,
OpenAIResponseInputTool,
OpenAIResponseInputToolMCP,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseContentPartAdded,
OpenAIResponseObjectStreamResponseContentPartDone,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
OpenAIResponseObjectStreamResponseMcpListToolsCompleted,
OpenAIResponseObjectStreamResponseMcpListToolsInProgress,
OpenAIResponseObjectStreamResponseOutputItemAdded,
OpenAIResponseObjectStreamResponseOutputItemDone,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseOutput,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseText,
WebSearchToolTypes,
)
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionToolCall,
OpenAIChoice,
)
from llama_stack.log import get_logger
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
logger = get_logger(name=__name__, category="responses")
class StreamingResponseOrchestrator:
def __init__(
self,
inference_api: Inference,
ctx: ChatCompletionContext,
response_id: str,
created_at: int,
text: OpenAIResponseText,
max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class
):
self.inference_api = inference_api
self.ctx = ctx
self.response_id = response_id
self.created_at = created_at
self.text = text
self.max_infer_iters = max_infer_iters
self.tool_executor = tool_executor
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
# Initialize output messages
output_messages: list[OpenAIResponseOutput] = []
# Create initial response and emit response.created immediately
initial_response = OpenAIResponseObject(
created_at=self.created_at,
id=self.response_id,
model=self.ctx.model,
object="response",
status="in_progress",
output=output_messages.copy(),
text=self.text,
)
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# Process all tools (including MCP tools) and emit streaming events
if self.ctx.response_tools:
async for stream_event in self._process_tools(self.ctx.response_tools, output_messages):
yield stream_event
n_iter = 0
messages = self.ctx.messages.copy()
while True:
completion_result = await self.inference_api.openai_chat_completion(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
stream=True,
temperature=self.ctx.temperature,
response_format=self.ctx.response_format,
)
# Process streaming chunks and build complete response
completion_result_data = None
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
if isinstance(stream_event_or_result, ChatCompletionResult):
completion_result_data = stream_event_or_result
else:
yield stream_event_or_result
if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data")
current_response = self._build_chat_completion(completion_result_data)
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
current_response, messages
)
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
output_messages.append(await convert_chat_choice_to_response_message(choice))
# Execute tool calls and coordinate results
async for stream_event in self._coordinate_tool_execution(
function_tool_calls,
non_function_tool_calls,
completion_result_data,
output_messages,
next_turn_messages,
):
yield stream_event
if not function_tool_calls and not non_function_tool_calls:
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
n_iter += 1
if n_iter >= self.max_infer_iters:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
break
messages = next_turn_messages
# Create final response
final_response = OpenAIResponseObject(
created_at=self.created_at,
id=self.response_id,
model=self.ctx.model,
object="response",
status="completed",
text=self.text,
output=output_messages,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
"""Separate tool calls into function and non-function categories."""
function_tool_calls = []
non_function_tool_calls = []
next_turn_messages = messages.copy()
for choice in current_response.choices:
next_turn_messages.append(choice.message)
if choice.message.tool_calls and self.ctx.response_tools:
for tool_call in choice.message.tool_calls:
if is_function_tool_call(tool_call, self.ctx.response_tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
return function_tool_calls, non_function_tool_calls, next_turn_messages
async def _process_streaming_chunks(
self, completion_result, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]:
"""Process streaming chunks and emit events, returning completion data."""
# Initialize result tracking
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
# Track tool call items for streaming events
tool_call_item_ids: dict[int, str] = {}
# Track content parts for streaming events
content_part_emitted = False
async for chunk in completion_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
# Emit content_part.added event for first text chunk
if not content_part_emitted:
content_part_emitted = True
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
response_id=self.response_id,
item_id=message_item_id,
part=OpenAIResponseContentPartOutputText(
text="", # Will be filled incrementally via text deltas
),
sequence_number=self.sequence_number,
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
sequence_number=self.sequence_number,
)
# Collect content for final response
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
# Create new tool call entry if this is the first chunk for this index
is_new_tool_call = response_tool_call is None
if is_new_tool_call:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Create item ID for this tool call for streaming events
tool_call_item_id = f"fc_{uuid.uuid4()}"
tool_call_item_ids[tool_call.index] = tool_call_item_id
# Emit output_item.added event for the new function call
self.sequence_number += 1
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
arguments="", # Will be filled incrementally via delta events
call_id=tool_call.id or "",
name=tool_call.function.name if tool_call.function else "",
id=tool_call_item_id,
status="in_progress",
)
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=function_call_item,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
# Stream tool call arguments as they arrive (differentiate between MCP and function calls)
if tool_call.function and tool_call.function.arguments:
tool_call_item_id = tool_call_item_ids[tool_call.index]
self.sequence_number += 1
# Check if this is an MCP tool call
is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
if is_mcp_tool:
# Emit MCP-specific argument delta event
yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(
delta=tool_call.function.arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
else:
# Emit function call argument delta event
yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
delta=tool_call.function.arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
# Accumulate arguments for final response (only for subsequent chunks)
if not is_new_tool_call:
response_tool_call.function.arguments = (
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call_item_id = tool_call_item_ids[tool_call_index]
final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
# Check if this is an MCP tool call
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
self.sequence_number += 1
done_event_cls = (
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone
if is_mcp_tool
else OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone
)
yield done_event_cls(
arguments=final_arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
# Emit content_part.done event if text content was streamed (before content gets cleared)
if content_part_emitted:
final_text = "".join(chat_response_content)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
response_id=self.response_id,
item_id=message_item_id,
part=OpenAIResponseContentPartOutputText(
text=final_text,
),
sequence_number=self.sequence_number,
)
# Clear content when there are tool calls (OpenAI spec behavior)
if chat_response_tool_calls:
chat_response_content = []
yield ChatCompletionResult(
response_id=chat_response_id,
content=chat_response_content,
tool_calls=chat_response_tool_calls,
created=chunk_created,
model=chunk_model,
finish_reason=chunk_finish_reason,
message_item_id=message_item_id,
tool_call_item_ids=tool_call_item_ids,
content_part_emitted=content_part_emitted,
)
def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion:
"""Build OpenAIChatCompletion from ChatCompletionResult."""
# Convert collected chunks to complete response
if result.tool_calls:
tool_calls = [result.tool_calls[i] for i in sorted(result.tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content=result.content_text,
tool_calls=tool_calls,
)
return OpenAIChatCompletion(
id=result.response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=result.finish_reason,
index=0,
)
],
created=result.created,
model=result.model,
)
async def _coordinate_tool_execution(
self,
function_tool_calls: list,
non_function_tool_calls: list,
completion_result_data: ChatCompletionResult,
output_messages: list[OpenAIResponseOutput],
next_turn_messages: list,
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Coordinate execution of both function and non-function tool calls."""
# Execute non-function tool calls
for tool_call in non_function_tool_calls:
# Find the item_id for this tool call
matching_item_id = None
for index, item_id in completion_result_data.tool_call_item_ids.items():
response_tool_call = completion_result_data.tool_calls.get(index)
if response_tool_call and response_tool_call.id == tool_call.id:
matching_item_id = item_id
break
# Use a fallback item_id if not found
if not matching_item_id:
matching_item_id = f"tc_{uuid.uuid4()}"
# Execute tool call with streaming
tool_call_log = None
tool_response_message = None
async for result in self.tool_executor.execute_tool_call(
tool_call,
self.ctx,
self.sequence_number,
len(output_messages),
matching_item_id,
self.mcp_tool_to_server,
):
if result.stream_event:
# Forward streaming events
self.sequence_number = result.sequence_number
yield result.stream_event
if result.final_output_message is not None:
tool_call_log = result.final_output_message
tool_response_message = result.final_input_message
self.sequence_number = result.sequence_number
if tool_call_log:
output_messages.append(tool_call_log)
# Emit output_item.done event for completed non-function tool call
if matching_item_id:
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=tool_call_log,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
if tool_response_message:
next_turn_messages.append(tool_response_message)
# Execute function tool calls (client-side)
for tool_call in function_tool_calls:
# Find the item_id for this tool call from our tracking dictionary
matching_item_id = None
for index, item_id in completion_result_data.tool_call_item_ids.items():
response_tool_call = completion_result_data.tool_calls.get(index)
if response_tool_call and response_tool_call.id == tool_call.id:
matching_item_id = item_id
break
# Use existing item_id or create new one if not found
final_item_id = matching_item_id or f"fc_{uuid.uuid4()}"
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=final_item_id,
status="completed",
)
output_messages.append(function_call_item)
# Emit output_item.done event for completed function call
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=function_call_item,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
async def _process_tools(
self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Process all tools and emit appropriate streaming events."""
from openai.types.chat import ChatCompletionToolParam
from llama_stack.apis.tools import Tool
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
)
return convert_tooldef_to_openai_tool(tool_def)
# Initialize chat_tools if not already set
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
for input_tool in tools:
if input_tool.type == "function":
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
elif input_tool.type in WebSearchToolTypes:
tool_name = "web_search"
# Need to access tool_groups_api from tool_executor
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "file_search":
tool_name = "knowledge_search"
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "mcp":
async for stream_event in self._process_mcp_tool(input_tool, output_messages):
yield stream_event
else:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
async def _process_mcp_tool(
self, mcp_tool: OpenAIResponseInputToolMCP, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Process an MCP tool configuration and emit appropriate streaming events."""
from llama_stack.providers.utils.tools.mcp import list_mcp_tools
# Emit mcp_list_tools.in_progress
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
sequence_number=self.sequence_number,
)
try:
# Parse allowed/never allowed tools
always_allowed = None
never_allowed = None
if mcp_tool.allowed_tools:
if isinstance(mcp_tool.allowed_tools, list):
always_allowed = mcp_tool.allowed_tools
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
always_allowed = mcp_tool.allowed_tools.always
never_allowed = mcp_tool.allowed_tools.never
# Call list_mcp_tools
tool_defs = await list_mcp_tools(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
)
# Create the MCP list tools message
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
server_label=mcp_tool.server_label,
tools=[],
)
# Process tools and update context
for t in tool_defs.data:
if never_allowed and t.name in never_allowed:
continue
if not always_allowed or t.name in always_allowed:
# Add to chat tools for inference
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
tool_def = ToolDefinition(
tool_name=t.name,
description=t.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in t.parameters
},
)
openai_tool = convert_tooldef_to_openai_tool(tool_def)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
# Add to MCP tool mapping
if t.name in self.mcp_tool_to_server:
raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}")
self.mcp_tool_to_server[t.name] = mcp_tool
# Add to MCP list message
mcp_list_message.tools.append(
MCPListToolsTool(
name=t.name,
description=t.description,
input_schema={
"type": "object",
"properties": {
p.name: {
"type": p.parameter_type,
"description": p.description,
}
for p in t.parameters
},
"required": [p.name for p in t.parameters if p.required],
},
)
)
# Add the MCP list message to output
output_messages.append(mcp_list_message)
# Emit output_item.added for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
# Emit mcp_list_tools.completed
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
sequence_number=self.sequence_number,
)
# Emit output_item.done for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
except Exception as e:
# TODO: Emit mcp_list_tools.failed event if needed
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
raise

View file

@ -0,0 +1,379 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from collections.abc import AsyncIterator
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolMCP,
OpenAIResponseObjectStreamResponseMcpCallCompleted,
OpenAIResponseObjectStreamResponseMcpCallFailed,
OpenAIResponseObjectStreamResponseMcpCallInProgress,
OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFileSearchToolCallResults,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.common.content_types import (
ImageContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIImageURL,
OpenAIToolMessageParam,
)
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from .types import ChatCompletionContext, ToolExecutionResult
logger = get_logger(name=__name__, category="responses")
class ToolExecutor:
def __init__(
self,
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
vector_io_api: VectorIO,
):
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.vector_io_api = vector_io_api
async def execute_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
ctx: ChatCompletionContext,
sequence_number: int,
output_index: int,
item_id: str,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]:
tool_call_id = tool_call.id
function = tool_call.function
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
if not function or not tool_call_id or not function.name:
yield ToolExecutionResult(sequence_number=sequence_number)
return
# Emit progress events for tool execution start
async for event_result in self._emit_progress_events(
function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server
):
sequence_number = event_result.sequence_number
yield event_result
# Execute the actual tool call
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
# Emit completion events for tool execution
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
async for event_result in self._emit_completion_events(
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
):
sequence_number = event_result.sequence_number
yield event_result
# Build result messages from tool execution
output_message, input_message = await self._build_result_messages(
function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
)
# Yield the final result
yield ToolExecutionResult(
sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
)
async def _execute_knowledge_search_via_vector_store(
self,
query: str,
response_file_search_tool: OpenAIResponseInputToolFileSearch,
) -> ToolInvocationResult:
"""Execute knowledge search using vector_stores.search API with filters support."""
search_results = []
# Create search tasks for all vector stores
async def search_single_store(vector_store_id):
try:
search_response = await self.vector_io_api.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=response_file_search_tool.filters,
max_num_results=response_file_search_tool.max_num_results,
ranking_options=response_file_search_tool.ranking_options,
rewrite_query=False,
)
return search_response.data
except Exception as e:
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
return []
# Run all searches in parallel using gather
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
all_results = await asyncio.gather(*search_tasks)
# Flatten results
for results in all_results:
search_results.extend(results)
# Convert search results to tool result format matching memory.py
# Format the results as interleaved content similar to memory.py
content_items = []
content_items.append(
TextContentItem(
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
)
for i, result_item in enumerate(search_results):
chunk_text = result_item.content[0].text if result_item.content else ""
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
if result_item.attributes:
metadata_text += f", attributes: {result_item.attributes}"
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
content_items.append(TextContentItem(text=text_content))
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
content_items.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
)
)
return ToolInvocationResult(
content=content_items,
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
"scores": [r.score for r in search_results],
},
)
async def _emit_progress_events(
self,
function_name: str,
ctx: ChatCompletionContext,
sequence_number: int,
output_index: int,
item_id: str,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]:
"""Emit progress events for tool execution start."""
# Emit in_progress event based on tool type (only for tools with specific streaming events)
progress_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
elif function_name == "web_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
if progress_event:
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
# For web search, emit searching event
if function_name == "web_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
async def _execute_tool(
self,
function_name: str,
tool_kwargs: dict,
ctx: ChatCompletionContext,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[Exception | None, any]:
"""Execute the tool and return error exception and result."""
error_exc = None
result = None
try:
if mcp_tool_to_server and function_name in mcp_tool_to_server:
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
mcp_tool = mcp_tool_to_server[function_name]
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function_name,
kwargs=tool_kwargs,
)
elif function_name == "knowledge_search":
response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
None,
)
if response_file_search_tool:
# Use vector_stores.search API instead of knowledge_search tool
# to support filters and ranking_options
query = tool_kwargs.get("query", "")
result = await self._execute_knowledge_search_via_vector_store(
query=query,
response_file_search_tool=response_file_search_tool,
)
else:
result = await self.tool_runtime_api.invoke_tool(
tool_name=function_name,
kwargs=tool_kwargs,
)
except Exception as e:
error_exc = e
return error_exc, result
async def _emit_completion_events(
self,
function_name: str,
ctx: ChatCompletionContext,
sequence_number: int,
output_index: int,
item_id: str,
has_error: bool,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]:
"""Emit completion or failure events for tool execution."""
completion_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1
if has_error:
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
sequence_number=sequence_number,
)
else:
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
sequence_number=sequence_number,
)
elif function_name == "web_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
if completion_event:
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
async def _build_result_messages(
self,
function,
tool_call_id: str,
tool_kwargs: dict,
ctx: ChatCompletionContext,
error_exc: Exception | None,
result: any,
has_error: bool,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[any, any]:
"""Build output and input messages from tool execution results."""
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
# Build output message
if mcp_tool_to_server and function.name in mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPCall,
)
message = OpenAIResponseOutputMessageMCPCall(
id=tool_call_id,
arguments=function.arguments,
name=function.name,
server_label=mcp_tool_to_server[function.name].server_label,
)
if error_exc:
message.error = str(error_exc)
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
message.error = f"Error (code {result.error_code}): {result.error_message}"
elif result and result.content:
message.output = interleaved_content_as_str(result.content)
else:
if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
)
if has_error:
message.status = "failed"
elif function.name == "knowledge_search":
message = OpenAIResponseOutputMessageFileSearchToolCall(
id=tool_call_id,
queries=[tool_kwargs.get("query", "")],
status="completed",
)
if result and "document_ids" in result.metadata:
message.results = []
for i, doc_id in enumerate(result.metadata["document_ids"]):
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
score = result.metadata["scores"][i] if "scores" in result.metadata else None
message.results.append(
OpenAIResponseOutputMessageFileSearchToolCallResults(
file_id=doc_id,
filename=doc_id,
text=text,
score=score,
attributes={},
)
)
if has_error:
message.status = "failed"
else:
raise ValueError(f"Unknown tool {function.name} called")
# Build input message
input_message = None
if result and result.content:
if isinstance(result.content, str):
content = result.content
elif isinstance(result.content, list):
content = []
for item in result.content:
if isinstance(item, TextContentItem):
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
elif isinstance(item, ImageContentItem):
if item.image.data:
url = f"data:image;base64,{item.image.data}"
else:
url = item.image.url
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
else:
raise ValueError(f"Unknown result content type: {type(item)}")
content.append(part)
else:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
else:
text = str(error_exc) if error_exc else "Tool execution failed"
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
return message, input_message

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from dataclasses import dataclass
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputTool,
OpenAIResponseObjectStream,
OpenAIResponseOutput,
)
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
class ToolExecutionResult(BaseModel):
"""Result of streaming tool execution."""
stream_event: OpenAIResponseObjectStream | None = None
sequence_number: int
final_output_message: OpenAIResponseOutput | None = None
final_input_message: OpenAIMessageParam | None = None
@dataclass
class ChatCompletionResult:
"""Result of processing streaming chat completion chunks."""
response_id: str
content: list[str]
tool_calls: dict[int, OpenAIChatCompletionToolCall]
created: int
model: str
finish_reason: str
message_item_id: str # For streaming events
tool_call_item_ids: dict[int, str] # For streaming events
content_part_emitted: bool # Tracking state
@property
def content_text(self) -> str:
"""Get joined content as string."""
return "".join(self.content)
@property
def has_tool_calls(self) -> bool:
"""Check if there are any tool calls."""
return bool(self.tool_calls)
class ChatCompletionContext(BaseModel):
model: str
messages: list[OpenAIMessageParam]
response_tools: list[OpenAIResponseInputTool] | None = None
chat_tools: list[ChatCompletionToolParam] | None = None
temperature: float | None
response_format: OpenAIResponseFormatParam

View file

@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseText,
)
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAIJSONSchema,
OpenAIMessageParam,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatParam,
OpenAIResponseFormatText,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
output_content = ""
if isinstance(choice.message.content, str):
output_content = choice.message.content
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
output_content = choice.message.content.text
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)
return OpenAIResponseMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
status="completed",
role="assistant",
)
async def convert_response_content_to_chat_content(
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
The content schemas of each API look similar, but are not exactly the same.
"""
if isinstance(content, str):
return content
converted_parts = []
for content_part in content:
if isinstance(content_part, OpenAIResponseInputMessageContentText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
if content_part.image_url:
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
elif isinstance(content_part, str):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
)
return converted_parts
async def convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.call_id,
)
)
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id=input_item.call_id,
function=OpenAIChatCompletionToolCallFunction(
name=input_item.name,
arguments=input_item.arguments,
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
else:
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)
if message_type is None:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
messages.append(message_type(content=content))
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages
async def convert_response_text_to_chat_response_format(
text: OpenAIResponseText,
) -> OpenAIResponseFormatParam:
"""
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
"""
if not text.format or text.format["type"] == "text":
return OpenAIResponseFormatText(type="text")
if text.format["type"] == "json_object":
return OpenAIResponseFormatJSONObject()
if text.format["type"] == "json_schema":
return OpenAIResponseFormatJSONSchema(
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
)
raise ValueError(f"Unsupported text format: {text.format}")
async def get_message_type_by_role(role: str):
role_to_type = {
"user": OpenAIUserMessageParam,
"system": OpenAISystemMessageParam,
"assistant": OpenAIAssistantMessageParam,
"developer": OpenAIDeveloperMessageParam,
}
return role_to_type.get(role)
def is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],
) -> bool:
if not tool_call.function:
return False
for t in tools:
if t.type == "function" and t.name == tool_call.function.name:
return True
return False

View file

@ -22,7 +22,7 @@ from llama_stack.apis.safety import (
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import Api
from llama_stack.models.llama.datatypes import Role
@ -72,30 +72,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
}
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
OPENAI_TO_LLAMA_CATEGORIES_MAP = {
OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES],
OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES],
OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION],
OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
OpenAICategories.HATE: [CAT_HATE],
OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES],
OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES],
OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION],
OpenAICategories.SELF_HARM: [CAT_SELF_HARM],
OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM],
OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
# These are custom categories that are not in the OpenAI moderation categories
"custom/defamation": [CAT_DEFAMATION],
"custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],
"custom/privacy_violation": [CAT_PRIVACY],
"custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY],
"custom/weapons": [CAT_INDISCRIMINATE_WEAPONS],
"custom/elections": [CAT_ELECTIONS],
"custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE],
}
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_VIOLENT_CRIMES,
@ -424,9 +400,9 @@ class LlamaGuardShield:
ModerationObject with appropriate configuration
"""
# Set default values for safe case
categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0)
category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
categories = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), False)
category_scores = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), 1.0)
category_applied_input_types = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
flagged = False
user_message = None
metadata = {}
@ -453,19 +429,15 @@ class LlamaGuardShield:
],
)
# Get OpenAI categories for the unsafe codes
openai_categories = []
for code in unsafe_code_list:
llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
openai_categories.extend(
k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
)
llama_guard_category = [SAFETY_CODE_TO_CATEGORIES_MAP[code] for code in unsafe_code_list]
# Update categories for unsafe content
categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
categories = {k: k in llama_guard_category for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()}
category_scores = {
k: 1.0 if k in llama_guard_category else 0.0 for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
}
category_applied_input_types = {
k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP
k: ["text"] if k in llama_guard_category else [] for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()
}
flagged = True
user_message = CANNED_RESPONSE_TEXT

View file

@ -15,8 +15,10 @@ from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ShieldStore,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@ -32,6 +34,8 @@ PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
shield_store: ShieldStore
def __init__(self, config: PromptGuardConfig, _deps) -> None:
self.config = config
@ -53,7 +57,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
self,
shield_id: str,
messages: list[Message],
params: dict[str, Any] = None,
params: dict[str, Any],
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
@ -61,6 +65,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
return await self.shield.run(messages)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
class PromptGuardShield:
def __init__(
@ -117,8 +124,10 @@ class PromptGuardShield:
elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
violation_type=f"prompt_injection:malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
user_message="Sorry, I cannot do this.",
metadata={
"violation_type": f"prompt_injection:malicious={score_malicious}",
},
)
return RunShieldResponse(violation=violation)

View file

@ -33,6 +33,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -128,11 +129,12 @@ class FaissIndex(EmbeddingIndex):
# Save updated index
await self._save_index()
async def delete_chunk(self, chunk_id: str) -> None:
if chunk_id not in self.chunk_ids:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
if not set(chunk_ids).issubset(self.chunk_ids):
return
async with self.chunk_id_lock:
def remove_chunk(chunk_id: str):
index = self.chunk_ids.index(chunk_id)
self.index.remove_ids(np.array([index]))
@ -146,6 +148,10 @@ class FaissIndex(EmbeddingIndex):
self.chunk_by_index = new_chunk_by_index
self.chunk_ids.pop(index)
async with self.chunk_id_lock:
for chunk_id in chunk_ids:
remove_chunk(chunk_id)
await self._save_index()
async def query_vector(
@ -174,7 +180,9 @@ class FaissIndex(EmbeddingIndex):
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in FAISS")
raise NotImplementedError(
"Keyword search is not supported - underlying DB FAISS does not support this search mode"
)
async def query_hybrid(
self,
@ -185,7 +193,9 @@ class FaissIndex(EmbeddingIndex):
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in FAISS")
raise NotImplementedError(
"Hybrid search is not supported - underlying DB FAISS does not support this search mode"
)
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
@ -293,8 +303,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a faiss index"""
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a faiss index"""
faiss_index = self.cache[store_id].index
for chunk_id in chunk_ids:
await faiss_index.delete_chunk(chunk_id)
await faiss_index.delete_chunks(chunks_for_deletion)

View file

@ -31,6 +31,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIV
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
RERANKER_TYPE_WEIGHTED,
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -426,34 +427,36 @@ class SQLiteVecIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Remove a chunk from the SQLite vector store."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
def _delete_chunk():
def _delete_chunks():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
cur.execute("BEGIN TRANSACTION")
# Delete from metadata table
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,))
placeholders = ",".join("?" * len(chunk_ids))
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id IN ({placeholders})", chunk_ids)
# Delete from vector table
cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,))
cur.execute(f"DELETE FROM {self.vector_table} WHERE id IN ({placeholders})", chunk_ids)
# Delete from FTS table
cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,))
cur.execute(f"DELETE FROM {self.fts_table} WHERE id IN ({placeholders})", chunk_ids)
connection.commit()
except Exception as e:
connection.rollback()
logger.error(f"Error deleting chunk {chunk_id}: {e}")
logger.error(f"Error deleting chunks: {e}")
raise
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_chunk)
await asyncio.to_thread(_delete_chunks)
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
@ -551,12 +554,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
raise VectorStoreNotFoundError(vector_db_id)
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a sqlite_vec index."""
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a sqlite_vec index."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise VectorStoreNotFoundError(store_id)
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -213,6 +213,36 @@ def available_providers() -> list[ProviderSpec]:
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="vertexai",
pip_packages=["litellm", "google-cloud-aiplatform"],
module="llama_stack.providers.remote.inference.vertexai",
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
Enterprise-grade security: Uses Google Cloud's security controls and IAM
Better integration: Seamless integration with other Google Cloud services
Advanced features: Access to additional Vertex AI features like model tuning and monitoring
Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys
Configuration:
- Set VERTEX_AI_PROJECT environment variable (required)
- Set VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1)
- Use Google Cloud Application Default Credentials or service account key
Authentication Setup:
Option 1 (Recommended): gcloud auth application-default login
Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to service account key path
Available Models:
- vertex_ai/gemini-2.0-flash
- vertex_ai/gemini-2.5-flash
- vertex_ai/gemini-2.5-pro""",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(

View file

@ -45,6 +45,18 @@ That means you'll get fast and efficient vector retrieval.
- Lightweight and easy to use
- Fully integrated with Llama Stack
- GPU support
- **Vector search** - FAISS supports pure vector similarity search using embeddings
## Search Modes
**Supported:**
- **Vector Search** (`mode="vector"`): Performs vector similarity search using embeddings
**Not Supported:**
- **Keyword Search** (`mode="keyword"`): Not supported by FAISS
- **Hybrid Search** (`mode="hybrid"`): Not supported by FAISS
> **Note**: FAISS is designed as a pure vector similarity search library. See the [FAISS GitHub repository](https://github.com/facebookresearch/faiss) for more details about FAISS's core functionality.
## Usage
@ -330,6 +342,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
@ -338,6 +351,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
module="llama_stack.providers.inline.vector_io.chroma",
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
@ -452,6 +466,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
@ -535,6 +550,7 @@ That means you're not limited to storing vectors in memory or in a separate serv
- Easy to use
- Fully integrated with Llama Stack
- Supports all search modes: vector, keyword, and hybrid search (both inline and remote configurations)
## Usage
@ -625,6 +641,92 @@ vector_io:
- **`client_pem_path`**: Path to the **client certificate** file (required for mTLS).
- **`client_key_path`**: Path to the **client private key** file (required for mTLS).
## Search Modes
Milvus supports three different search modes for both inline and remote configurations:
### Vector Search
Vector search uses semantic similarity to find the most relevant chunks based on embedding vectors. This is the default search mode and works well for finding conceptually similar content.
```python
# Vector search example
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="What is machine learning?",
search_mode="vector",
max_num_results=5,
)
```
### Keyword Search
Keyword search uses traditional text-based matching to find chunks containing specific terms or phrases. This is useful when you need exact term matches.
```python
# Keyword search example
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="Python programming language",
search_mode="keyword",
max_num_results=5,
)
```
### Hybrid Search
Hybrid search combines both vector and keyword search methods to provide more comprehensive results. It leverages the strengths of both semantic similarity and exact term matching.
#### Basic Hybrid Search
```python
# Basic hybrid search example (uses RRF ranker with default impact_factor=60.0)
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="neural networks in Python",
search_mode="hybrid",
max_num_results=5,
)
```
**Note**: The default `impact_factor` value of 60.0 was empirically determined to be optimal in the original RRF research paper: ["Reciprocal Rank Fusion outperforms Condorcet and individual Rank Learning Methods"](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) (Cormack et al., 2009).
#### Hybrid Search with RRF (Reciprocal Rank Fusion) Ranker
RRF combines rankings from vector and keyword search by using reciprocal ranks. The impact factor controls how much weight is given to higher-ranked results.
```python
# Hybrid search with custom RRF parameters
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="neural networks in Python",
search_mode="hybrid",
max_num_results=5,
ranking_options={
"ranker": {
"type": "rrf",
"impact_factor": 100.0, # Higher values give more weight to top-ranked results
}
},
)
```
#### Hybrid Search with Weighted Ranker
Weighted ranker linearly combines normalized scores from vector and keyword search. The alpha parameter controls the balance between the two search methods.
```python
# Hybrid search with weighted ranker
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="neural networks in Python",
search_mode="hybrid",
max_num_results=5,
ranking_options={
"ranker": {
"type": "weighted",
"alpha": 0.7, # 70% vector search, 30% keyword search
}
},
)
```
For detailed documentation on RRF and Weighted rankers, please refer to the [Milvus Reranking Guide](https://milvus.io/docs/reranking.md).
## Documentation
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
@ -632,6 +734,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,

View file

@ -235,6 +235,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
# TODO: tools are never added to the request, so we need to add them here
if media_present or not llama_model:
input_dict["messages"] = [
await convert_message_to_openai_dict(m, download=True) for m in request.messages
@ -378,6 +379,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
# Fireworks chat completions OpenAI-compatible API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
self,
@ -431,4 +433,5 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
user=user,
)
logger.debug(f"fireworks params: {params}")
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)

View file

@ -457,9 +457,6 @@ class OllamaInferenceAdapter(
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
model_obj = await self._get_model(model)
if model_obj.model_type != ModelType.embedding:
raise ValueError(f"Model {model} is not an embedding model")
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {model} has no provider_resource_id set")

View file

@ -308,9 +308,7 @@ class TGIAdapter(_HfAdapter):
if not config.url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient(
model=config.url,
)
self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import VertexAIConfig
async def get_adapter_impl(config: VertexAIConfig, _deps):
from .vertexai import VertexAIInferenceAdapter
impl = VertexAIInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class VertexAIProviderDataValidator(BaseModel):
vertex_project: str | None = Field(
default=None,
description="Google Cloud project ID for Vertex AI",
)
vertex_location: str | None = Field(
default=None,
description="Google Cloud location for Vertex AI (e.g., us-central1)",
)
@json_schema_type
class VertexAIConfig(BaseModel):
project: str = Field(
description="Google Cloud project ID for Vertex AI",
)
location: str = Field(
default="us-central1",
description="Google Cloud location for Vertex AI",
)
@classmethod
def sample_run_config(
cls,
project: str = "${env.VERTEX_AI_PROJECT:=}",
location: str = "${env.VERTEX_AI_LOCATION:=us-central1}",
**kwargs,
) -> dict[str, Any]:
return {
"project": project,
"location": location,
}

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
# Vertex AI model IDs with vertex_ai/ prefix as required by litellm
LLM_MODEL_IDS = [
"vertex_ai/gemini-2.0-flash",
"vertex_ai/gemini-2.5-flash",
"vertex_ai/gemini-2.5-pro",
]
SAFETY_MODELS_ENTRIES = list[ProviderModelEntry]()
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + SAFETY_MODELS_ENTRIES

View file

@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from .config import VertexAIConfig
from .models import MODEL_ENTRIES
class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: VertexAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="vertex_ai",
api_key_from_config=None, # Vertex AI uses ADC, not API keys
provider_data_api_key_field="vertex_project", # Use project for validation
)
self.config = config
def get_api_key(self) -> str:
# Vertex AI doesn't use API keys, it uses Application Default Credentials
# Return empty string to let litellm handle authentication via ADC
return ""
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
# Add Vertex AI specific parameters
provider_data = self.get_request_provider_data()
if provider_data:
if getattr(provider_data, "vertex_project", None):
params["vertex_project"] = provider_data.vertex_project
if getattr(provider_data, "vertex_location", None):
params["vertex_location"] = provider_data.vertex_location
else:
params["vertex_project"] = self.config.project
params["vertex_location"] = self.config.location
# Remove api_key since Vertex AI uses ADC
params.pop("api_key", None)
return params

View file

@ -26,6 +26,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -115,8 +116,10 @@ class ChromaIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
async def delete_chunk(self, chunk_id: str) -> None:
raise NotImplementedError("delete_chunk is not supported in Chroma")
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a single chunk from the Chroma collection by its ID."""
ids = [f"{chunk.document_id}:{chunk.chunk_id}" for chunk in chunks_for_deletion]
await maybe_await(self.collection.delete(ids=ids))
async def query_hybrid(
self,
@ -144,6 +147,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache = {}
self.kvstore: KVStore | None = None
self.vector_db_store = None
self.files_api = files_api
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
@ -227,5 +231,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache[vector_db_id] = index
return index
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a Chroma vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -28,6 +28,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_WEIGHTED,
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -287,14 +288,17 @@ class MilvusIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Remove a chunk from the Milvus collection."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
try:
# Use IN clause with square brackets and single quotes for VARCHAR field
chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
await asyncio.to_thread(
self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
)
except Exception as e:
logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
raise
@ -420,12 +424,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a chunk from a milvus vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise VectorStoreNotFoundError(store_id)
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -163,10 +164,11 @@ class PGVectorIndex(EmbeddingIndex):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Remove a chunk from the PostgreSQL table."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,))
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
@ -275,12 +277,10 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
return self.cache[vector_db_id]
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a chunk from a PostgreSQL vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise VectorStoreNotFoundError(store_id)
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -29,6 +29,7 @@ from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig a
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -88,15 +89,16 @@ class QdrantIndex(EmbeddingIndex):
await self.client.upsert(collection_name=self.collection_name, points=points)
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Remove a chunk from the Qdrant collection."""
chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
try:
await self.client.delete(
collection_name=self.collection_name,
points_selector=models.PointIdsList(points=[convert_id(chunk_id)]),
points_selector=models.PointIdsList(points=chunk_ids),
)
except Exception as e:
log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}")
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
raise
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
@ -264,12 +266,14 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
) -> VectorStoreFileObject:
# Qdrant doesn't allow multiple clients to access the same storage path simultaneously.
async with self._qdrant_lock:
await super().openai_attach_file_to_vector_store(vector_store_id, file_id, attributes, chunking_strategy)
return await super().openai_attach_file_to_vector_store(
vector_store_id, file_id, attributes, chunking_strategy
)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a Qdrant vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
for chunk_id in chunk_ids:
await index.index.delete_chunk(chunk_id)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -26,6 +26,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
OpenAIVectorStoreMixin,
)
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -67,6 +68,7 @@ class WeaviateIndex(EmbeddingIndex):
data_objects.append(
wvc.data.DataObject(
properties={
"chunk_id": chunk.chunk_id,
"chunk_content": chunk.model_dump_json(),
},
vector=embeddings[i].tolist(),
@ -79,10 +81,11 @@ class WeaviateIndex(EmbeddingIndex):
# TODO: make this async friendly
collection.data.insert_many(data_objects)
async def delete_chunk(self, chunk_id: str) -> None:
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
collection = self.client.collections.get(sanitized_collection_name)
collection.data.delete_many(where=Filter.by_property("id").contains_any([chunk_id]))
chunk_ids = [chunk.chunk_id for chunk in chunks_for_deletion]
collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
@ -307,10 +310,10 @@ class WeaviateVectorIOAdapter(
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True)
index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
if not index:
raise ValueError(f"Vector DB {sanitized_collection_name} not found")
await index.delete(chunk_ids)
await index.index.delete_chunks(chunks_for_deletion)

View file

@ -31,15 +31,15 @@ from openai.types.chat import (
from openai.types.chat import (
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
)
from openai.types.chat import (
ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall,
)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
)
from openai.types.chat import (
ChatCompletionMessageToolCall,
)
from openai.types.chat import (
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
)
from openai.types.chat import (
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
)
@ -70,7 +70,7 @@ from openai.types.chat.chat_completion_chunk import (
from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
from openai.types.chat.chat_completion_message_tool_call import (
Function as OpenAIFunction,
)
from pydantic import BaseModel
@ -633,7 +633,7 @@ async def convert_message_to_openai_dict_new(
)
elif isinstance(message, CompletionMessage):
tool_calls = [
OpenAIChatCompletionMessageToolCall(
OpenAIChatCompletionMessageFunctionToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
@ -903,7 +903,7 @@ def _convert_openai_request_response_format(
def _convert_openai_tool_calls(
tool_calls: list[OpenAIChatCompletionMessageToolCall],
tool_calls: list[OpenAIChatCompletionMessageFunctionToolCall],
) -> list[ToolCall]:
"""
Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.

View file

@ -6,7 +6,6 @@
import asyncio
import json
import logging
import mimetypes
import time
import uuid
@ -37,10 +36,15 @@ from llama_stack.apis.vector_io import (
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
content_from_data_and_mime_type,
make_overlapped_chunks,
)
logger = logging.getLogger(__name__)
logger = get_logger(__name__, category="vector_io")
# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5
@ -154,8 +158,8 @@ class OpenAIVectorStoreMixin(ABC):
self.openai_vector_stores = await self._load_openai_vector_stores()
@abstractmethod
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a vector store."""
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a vector store."""
pass
@abstractmethod
@ -614,7 +618,7 @@ class OpenAIVectorStoreMixin(ABC):
)
vector_store_file_object.status = "completed"
except Exception as e:
logger.error(f"Error attaching file to vector store: {e}")
logger.exception("Error attaching file to vector store")
vector_store_file_object.status = "failed"
vector_store_file_object.last_error = VectorStoreFileLastError(
code="server_error",
@ -767,7 +771,21 @@ class OpenAIVectorStoreMixin(ABC):
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
chunks = [Chunk.model_validate(c) for c in dict_chunks]
await self.delete_chunks(vector_store_id, [str(c.chunk_id) for c in chunks if c.chunk_id])
# Create ChunkForDeletion objects with both chunk_id and document_id
chunks_for_deletion = []
for c in chunks:
if c.chunk_id:
document_id = c.metadata.get("document_id") or (
c.chunk_metadata.document_id if c.chunk_metadata else None
)
if document_id:
chunks_for_deletion.append(ChunkForDeletion(chunk_id=str(c.chunk_id), document_id=document_id))
else:
logger.warning(f"Chunk {c.chunk_id} has no document_id, skipping deletion")
if chunks_for_deletion:
await self.delete_chunks(vector_store_id, chunks_for_deletion)
store_info = self.openai_vector_stores[vector_store_id].copy()

View file

@ -16,6 +16,7 @@ from urllib.parse import unquote
import httpx
import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
URL,
@ -34,6 +35,18 @@ from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
log = logging.getLogger(__name__)
class ChunkForDeletion(BaseModel):
"""Information needed to delete a chunk from a vector store.
:param chunk_id: The ID of the chunk to delete
:param document_id: The ID of the document this chunk belongs to
"""
chunk_id: str
document_id: str
# Constants for reranker types
RERANKER_TYPE_RRF = "rrf"
RERANKER_TYPE_WEIGHTED = "weighted"
@ -232,7 +245,7 @@ class EmbeddingIndex(ABC):
raise NotImplementedError()
@abstractmethod
async def delete_chunk(self, chunk_id: str):
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]):
raise NotImplementedError()
@abstractmethod

1
llama_stack/ui/.nvmrc Normal file
View file

@ -0,0 +1 @@
22.5.1

View file

@ -1,3 +1,12 @@
# Ignore artifacts:
build
coverage
.next
node_modules
dist
*.lock
*.log
# Generated files
*.min.js
*.min.css

View file

@ -1 +1,10 @@
{}
{
"semi": true,
"trailingComma": "es5",
"singleQuote": false,
"printWidth": 80,
"tabWidth": 2,
"useTabs": false,
"bracketSpacing": true,
"arrowParens": "avoid"
}

View file

@ -47,7 +47,7 @@ async function proxyRequest(request: NextRequest, method: string) {
const responseText = await response.text();
console.log(
`Response from FastAPI: ${response.status} ${response.statusText}`,
`Response from FastAPI: ${response.status} ${response.statusText}`
);
// Create response with same status and headers
@ -74,7 +74,7 @@ async function proxyRequest(request: NextRequest, method: string) {
backend_url: BACKEND_URL,
timestamp: new Date().toISOString(),
},
{ status: 500 },
{ status: 500 }
);
}
}

View file

@ -51,9 +51,9 @@ export default function SignInPage() {
onClick={() => {
console.log("Signing in with GitHub...");
signIn("github", { callbackUrl: "/auth/signin" }).catch(
(error) => {
error => {
console.error("Sign in error:", error);
},
}
);
}}
className="w-full"

View file

@ -29,14 +29,13 @@ export default function ChatPlaygroundPage() {
const isModelsLoading = modelsLoading ?? true;
useEffect(() => {
const fetchModels = async () => {
try {
setModelsLoading(true);
setModelsError(null);
const modelList = await client.models.list();
const llmModels = modelList.filter(model => model.model_type === 'llm');
const llmModels = modelList.filter(model => model.model_type === "llm");
setModels(llmModels);
if (llmModels.length > 0) {
setSelectedModel(llmModels[0].identifier);
@ -53,26 +52,42 @@ export default function ChatPlaygroundPage() {
}, [client]);
const extractTextContent = (content: unknown): string => {
if (typeof content === 'string') {
if (typeof content === "string") {
return content;
}
if (Array.isArray(content)) {
return content
.filter(item => item && typeof item === 'object' && 'type' in item && item.type === 'text')
.map(item => (item && typeof item === 'object' && 'text' in item) ? String(item.text) : '')
.join('');
.filter(
item =>
item &&
typeof item === "object" &&
"type" in item &&
item.type === "text"
)
.map(item =>
item && typeof item === "object" && "text" in item
? String(item.text)
: ""
)
.join("");
}
if (content && typeof content === 'object' && 'type' in content && content.type === 'text' && 'text' in content) {
return String(content.text) || '';
if (
content &&
typeof content === "object" &&
"type" in content &&
content.type === "text" &&
"text" in content
) {
return String(content.text) || "";
}
return '';
return "";
};
const handleInputChange = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
setInput(e.target.value);
};
const handleSubmit = async (event?: { preventDefault?: () => void }) => {
const handleSubmit = async (event?: { preventDefault?: () => void }) => {
event?.preventDefault?.();
if (!input.trim()) return;
@ -89,16 +104,19 @@ const handleSubmit = async (event?: { preventDefault?: () => void }) => {
// Use the helper function with the content
await handleSubmitWithContent(userMessage.content);
};
};
const handleSubmitWithContent = async (content: string) => {
const handleSubmitWithContent = async (content: string) => {
setIsGenerating(true);
setError(null);
try {
const messageParams: CompletionCreateParams["messages"] = [
...messages.map(msg => {
const msgContent = typeof msg.content === 'string' ? msg.content : extractTextContent(msg.content);
const msgContent =
typeof msg.content === "string"
? msg.content
: extractTextContent(msg.content);
if (msg.role === "user") {
return { role: "user" as const, content: msgContent };
} else if (msg.role === "assistant") {
@ -107,7 +125,7 @@ const handleSubmitWithContent = async (content: string) => {
return { role: "system" as const, content: msgContent };
}
}),
{ role: "user" as const, content }
{ role: "user" as const, content },
];
const response = await client.chat.completions.create({
@ -149,7 +167,7 @@ const handleSubmitWithContent = async (content: string) => {
} finally {
setIsGenerating(false);
}
};
};
const suggestions = [
"Write a Python function that prints 'Hello, World!'",
"Explain step-by-step how to solve this math problem: If x² + 6x + 9 = 25, what is x?",
@ -163,7 +181,7 @@ const handleSubmitWithContent = async (content: string) => {
content: message.content,
createdAt: new Date(),
};
setMessages(prev => [...prev, newMessage])
setMessages(prev => [...prev, newMessage]);
handleSubmitWithContent(newMessage.content);
};
@ -175,14 +193,22 @@ const handleSubmitWithContent = async (content: string) => {
return (
<div className="flex flex-col h-full max-w-4xl mx-auto">
<div className="mb-4 flex justify-between items-center">
<h1 className="text-2xl font-bold">Chat Playground</h1>
<h1 className="text-2xl font-bold">Chat Playground (Completions)</h1>
<div className="flex gap-2">
<Select value={selectedModel} onValueChange={setSelectedModel} disabled={isModelsLoading || isGenerating}>
<Select
value={selectedModel}
onValueChange={setSelectedModel}
disabled={isModelsLoading || isGenerating}
>
<SelectTrigger className="w-[180px]">
<SelectValue placeholder={isModelsLoading ? "Loading models..." : "Select Model"} />
<SelectValue
placeholder={
isModelsLoading ? "Loading models..." : "Select Model"
}
/>
</SelectTrigger>
<SelectContent>
{models.map((model) => (
{models.map(model => (
<SelectItem key={model.identifier} value={model.identifier}>
{model.identifier}
</SelectItem>

View file

@ -33,12 +33,12 @@ export default function ChatCompletionDetailPage() {
} catch (err) {
console.error(
`Error fetching chat completion detail for ID ${id}:`,
err,
err
);
setError(
err instanceof Error
? err
: new Error("Failed to fetch completion detail"),
: new Error("Failed to fetch completion detail")
);
} finally {
setIsLoading(false);

View file

@ -13,10 +13,10 @@ export default function ResponseDetailPage() {
const client = useAuthClient();
const [responseDetail, setResponseDetail] = useState<OpenAIResponse | null>(
null,
null
);
const [inputItems, setInputItems] = useState<InputItemListResponse | null>(
null,
null
);
const [isLoading, setIsLoading] = useState<boolean>(true);
const [isLoadingInputItems, setIsLoadingInputItems] = useState<boolean>(true);
@ -25,7 +25,7 @@ export default function ResponseDetailPage() {
// Helper function to convert ResponseObject to OpenAIResponse
const convertResponseObject = (
responseData: ResponseObject,
responseData: ResponseObject
): OpenAIResponse => {
return {
id: responseData.id,
@ -73,12 +73,12 @@ export default function ResponseDetailPage() {
} else {
console.error(
`Error fetching response detail for ID ${id}:`,
responseResult.reason,
responseResult.reason
);
setError(
responseResult.reason instanceof Error
? responseResult.reason
: new Error("Failed to fetch response detail"),
: new Error("Failed to fetch response detail")
);
}
@ -90,18 +90,18 @@ export default function ResponseDetailPage() {
} else {
console.error(
`Error fetching input items for response ID ${id}:`,
inputItemsResult.reason,
inputItemsResult.reason
);
setInputItemsError(
inputItemsResult.reason instanceof Error
? inputItemsResult.reason
: new Error("Failed to fetch input items"),
: new Error("Failed to fetch input items")
);
}
} catch (err) {
console.error(`Unexpected error fetching data for ID ${id}:`, err);
setError(
err instanceof Error ? err : new Error("Unexpected error occurred"),
err instanceof Error ? err : new Error("Unexpected error occurred")
);
} finally {
setIsLoading(false);

View file

@ -18,7 +18,10 @@ import {
PropertiesCard,
PropertyItem,
} from "@/components/layout/detail-layout";
import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb";
import {
PageBreadcrumb,
BreadcrumbSegment,
} from "@/components/layout/page-breadcrumb";
export default function ContentDetailPage() {
const params = useParams();
@ -28,13 +31,13 @@ export default function ContentDetailPage() {
const contentId = params.contentId as string;
const client = useAuthClient();
const getTextFromContent = (content: any): string => {
if (typeof content === 'string') {
const getTextFromContent = (content: unknown): string => {
if (typeof content === "string") {
return content;
} else if (content && content.type === 'text') {
} else if (content && content.type === "text") {
return content.text;
}
return '';
return "";
};
const [store, setStore] = useState<VectorStore | null>(null);
@ -44,7 +47,9 @@ export default function ContentDetailPage() {
const [error, setError] = useState<Error | null>(null);
const [isEditing, setIsEditing] = useState(false);
const [editedContent, setEditedContent] = useState("");
const [editedMetadata, setEditedMetadata] = useState<Record<string, any>>({});
const [editedMetadata, setEditedMetadata] = useState<Record<string, unknown>>(
{}
);
const [isEditingEmbedding, setIsEditingEmbedding] = useState(false);
const [editedEmbedding, setEditedEmbedding] = useState<number[]>([]);
@ -64,8 +69,13 @@ export default function ContentDetailPage() {
setFile(fileResponse as VectorStoreFile);
const contentsAPI = new ContentsAPI(client);
const contentsResponse = await contentsAPI.listContents(vectorStoreId, fileId);
const targetContent = contentsResponse.data.find(c => c.id === contentId);
const contentsResponse = await contentsAPI.listContents(
vectorStoreId,
fileId
);
const targetContent = contentsResponse.data.find(
c => c.id === contentId
);
if (targetContent) {
setContent(targetContent);
@ -76,7 +86,9 @@ export default function ContentDetailPage() {
throw new Error(`Content ${contentId} not found`);
}
} catch (err) {
setError(err instanceof Error ? err : new Error("Failed to load content."));
setError(
err instanceof Error ? err : new Error("Failed to load content.")
);
} finally {
setIsLoading(false);
}
@ -88,7 +100,8 @@ export default function ContentDetailPage() {
if (!content) return;
try {
const updates: { content?: string; metadata?: Record<string, any> } = {};
const updates: { content?: string; metadata?: Record<string, unknown> } =
{};
if (editedContent !== getTextFromContent(content.content)) {
updates.content = editedContent;
@ -100,25 +113,32 @@ export default function ContentDetailPage() {
if (Object.keys(updates).length > 0) {
const contentsAPI = new ContentsAPI(client);
const updatedContent = await contentsAPI.updateContent(vectorStoreId, fileId, contentId, updates);
const updatedContent = await contentsAPI.updateContent(
vectorStoreId,
fileId,
contentId,
updates
);
setContent(updatedContent);
}
setIsEditing(false);
} catch (err) {
console.error('Failed to update content:', err);
console.error("Failed to update content:", err);
}
};
const handleDelete = async () => {
if (!confirm('Are you sure you want to delete this content?')) return;
if (!confirm("Are you sure you want to delete this content?")) return;
try {
const contentsAPI = new ContentsAPI(client);
await contentsAPI.deleteContent(vectorStoreId, fileId, contentId);
router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`);
router.push(
`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`
);
} catch (err) {
console.error('Failed to delete content:', err);
console.error("Failed to delete content:", err);
}
};
@ -134,10 +154,19 @@ export default function ContentDetailPage() {
const breadcrumbSegments: BreadcrumbSegment[] = [
{ label: "Vector Stores", href: "/logs/vector-stores" },
{ label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` },
{
label: store?.name || vectorStoreId,
href: `/logs/vector-stores/${vectorStoreId}`,
},
{ label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
{ label: fileId, href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}` },
{ label: "Contents", href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents` },
{
label: fileId,
href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}`,
},
{
label: "Contents",
href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`,
},
{ label: contentId },
];
@ -186,7 +215,7 @@ export default function ContentDetailPage() {
{isEditing ? (
<textarea
value={editedContent}
onChange={(e) => setEditedContent(e.target.value)}
onChange={e => setEditedContent(e.target.value)}
className="w-full h-64 p-3 border rounded-md resize-none font-mono text-sm"
placeholder="Enter content..."
/>
@ -206,16 +235,23 @@ export default function ContentDetailPage() {
<div className="flex gap-2">
{isEditingEmbedding ? (
<>
<Button size="sm" onClick={() => {
<Button
size="sm"
onClick={() => {
setIsEditingEmbedding(false);
}}>
}}
>
<Save className="h-4 w-4 mr-1" />
Save
</Button>
<Button size="sm" variant="outline" onClick={() => {
<Button
size="sm"
variant="outline"
onClick={() => {
setEditedEmbedding(content?.embedding || []);
setIsEditingEmbedding(false);
}}>
}}
>
<X className="h-4 w-4 mr-1" />
Cancel
</Button>
@ -237,14 +273,16 @@ export default function ContentDetailPage() {
</p>
<textarea
value={JSON.stringify(editedEmbedding, null, 2)}
onChange={(e) => {
onChange={e => {
try {
const parsed = JSON.parse(e.target.value);
if (Array.isArray(parsed) && parsed.every(v => typeof v === 'number')) {
if (
Array.isArray(parsed) &&
parsed.every(v => typeof v === "number")
) {
setEditedEmbedding(parsed);
}
} catch {
}
} catch {}
}}
className="w-full h-32 p-3 border rounded-md resize-none font-mono text-xs"
placeholder="Enter embedding as JSON array..."
@ -259,8 +297,15 @@ export default function ContentDetailPage() {
</div>
<div className="p-3 bg-gray-50 dark:bg-gray-800 rounded-md max-h-32 overflow-y-auto">
<pre className="whitespace-pre-wrap font-mono text-xs text-gray-900 dark:text-gray-100">
[{content.embedding.slice(0, 20).map(v => v.toFixed(6)).join(', ')}
{content.embedding.length > 20 ? `\n... and ${content.embedding.length - 20} more values` : ''}]
[
{content.embedding
.slice(0, 20)
.map(v => v.toFixed(6))
.join(", ")}
{content.embedding.length > 20
? `\n... and ${content.embedding.length - 20} more values`
: ""}
]
</pre>
</div>
</div>
@ -284,7 +329,7 @@ export default function ContentDetailPage() {
<div key={key} className="flex gap-2">
<Input
value={key}
onChange={(e) => {
onChange={e => {
const newMetadata = { ...editedMetadata };
delete newMetadata[key];
newMetadata[e.target.value] = value;
@ -294,11 +339,13 @@ export default function ContentDetailPage() {
className="flex-1"
/>
<Input
value={typeof value === 'string' ? value : JSON.stringify(value)}
onChange={(e) => {
value={
typeof value === "string" ? value : JSON.stringify(value)
}
onChange={e => {
setEditedMetadata({
...editedMetadata,
[key]: e.target.value
[key]: e.target.value,
});
}}
placeholder="Value"
@ -312,7 +359,7 @@ export default function ContentDetailPage() {
onClick={() => {
setEditedMetadata({
...editedMetadata,
['']: ''
[""]: "",
});
}}
>
@ -325,7 +372,7 @@ export default function ContentDetailPage() {
<div key={key} className="flex justify-between py-1">
<span className="font-medium text-gray-600">{key}:</span>
<span className="font-mono text-sm">
{typeof value === 'string' ? value : JSON.stringify(value)}
{typeof value === "string" ? value : JSON.stringify(value)}
</span>
</div>
))}
@ -351,15 +398,15 @@ export default function ContentDetailPage() {
value={`${getTextFromContent(content.content).length} chars`}
/>
{content.metadata.chunk_window && (
<PropertyItem
label="Position"
value={content.metadata.chunk_window}
/>
<PropertyItem label="Position" value={content.metadata.chunk_window} />
)}
{file && (
<>
<PropertyItem label="File Status" value={file.status} />
<PropertyItem label="File Usage" value={`${file.usage_bytes} bytes`} />
<PropertyItem
label="File Usage"
value={`${file.usage_bytes} bytes`}
/>
</>
)}
{store && (

View file

@ -18,7 +18,10 @@ import {
PropertiesCard,
PropertyItem,
} from "@/components/layout/detail-layout";
import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb";
import {
PageBreadcrumb,
BreadcrumbSegment,
} from "@/components/layout/page-breadcrumb";
import {
Table,
TableBody,
@ -36,23 +39,21 @@ export default function ContentsListPage() {
const fileId = params.fileId as string;
const client = useAuthClient();
const getTextFromContent = (content: any): string => {
if (typeof content === 'string') {
const getTextFromContent = (content: unknown): string => {
if (typeof content === "string") {
return content;
} else if (content && content.type === 'text') {
} else if (content && content.type === "text") {
return content.text;
}
return '';
return "";
};
const [store, setStore] = useState<VectorStore | null>(null);
const [file, setFile] = useState<VectorStoreFile | null>(null);
const [contents, setContents] = useState<VectorStoreContentItem[]>([]);
const [isLoadingStore, setIsLoadingStore] = useState(true);
const [isLoadingFile, setIsLoadingFile] = useState(true);
const [isLoadingContents, setIsLoadingContents] = useState(true);
const [errorStore, setErrorStore] = useState<Error | null>(null);
const [errorFile, setErrorFile] = useState<Error | null>(null);
const [errorContents, setErrorContents] = useState<Error | null>(null);
useEffect(() => {
@ -65,7 +66,9 @@ export default function ContentsListPage() {
const response = await client.vectorStores.retrieve(vectorStoreId);
setStore(response as VectorStore);
} catch (err) {
setErrorStore(err instanceof Error ? err : new Error("Failed to load vector store."));
setErrorStore(
err instanceof Error ? err : new Error("Failed to load vector store.")
);
} finally {
setIsLoadingStore(false);
}
@ -80,10 +83,15 @@ export default function ContentsListPage() {
setIsLoadingFile(true);
setErrorFile(null);
try {
const response = await client.vectorStores.files.retrieve(vectorStoreId, fileId);
const response = await client.vectorStores.files.retrieve(
vectorStoreId,
fileId
);
setFile(response as VectorStoreFile);
} catch (err) {
setErrorFile(err instanceof Error ? err : new Error("Failed to load file."));
setErrorFile(
err instanceof Error ? err : new Error("Failed to load file.")
);
} finally {
setIsLoadingFile(false);
}
@ -99,10 +107,16 @@ export default function ContentsListPage() {
setErrorContents(null);
try {
const contentsAPI = new ContentsAPI(client);
const contentsResponse = await contentsAPI.listContents(vectorStoreId, fileId, { limit: 100 });
const contentsResponse = await contentsAPI.listContents(
vectorStoreId,
fileId,
{ limit: 100 }
);
setContents(contentsResponse.data);
} catch (err) {
setErrorContents(err instanceof Error ? err : new Error("Failed to load contents."));
setErrorContents(
err instanceof Error ? err : new Error("Failed to load contents.")
);
} finally {
setIsLoadingContents(false);
}
@ -116,26 +130,36 @@ export default function ContentsListPage() {
await contentsAPI.deleteContent(vectorStoreId, fileId, contentId);
setContents(contents.filter(content => content.id !== contentId));
} catch (err) {
console.error('Failed to delete content:', err);
console.error("Failed to delete content:", err);
}
};
const handleViewContent = (contentId: string) => {
router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents/${contentId}`);
router.push(
`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents/${contentId}`
);
};
const title = `Contents in File: ${fileId}`;
const breadcrumbSegments: BreadcrumbSegment[] = [
{ label: "Vector Stores", href: "/logs/vector-stores" },
{ label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` },
{
label: store?.name || vectorStoreId,
href: `/logs/vector-stores/${vectorStoreId}`,
},
{ label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
{ label: fileId, href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}` },
{
label: fileId,
href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}`,
},
{ label: "Contents" },
];
if (errorStore) {
return <DetailErrorView title={title} id={vectorStoreId} error={errorStore} />;
return (
<DetailErrorView title={title} id={vectorStoreId} error={errorStore} />
);
}
if (isLoadingStore) {
return <DetailLoadingView title={title} />;
@ -175,7 +199,7 @@ export default function ContentsListPage() {
</TableRow>
</TableHeader>
<TableBody>
{contents.map((content) => (
{contents.map(content => (
<TableRow key={content.id}>
<TableCell className="font-mono text-xs">
<Button
@ -189,7 +213,10 @@ export default function ContentsListPage() {
</TableCell>
<TableCell>
<div className="max-w-md">
<p className="text-sm truncate" title={getTextFromContent(content.content)}>
<p
className="text-sm truncate"
title={getTextFromContent(content.content)}
>
{getTextFromContent(content.content)}
</p>
</div>
@ -197,12 +224,25 @@ export default function ContentsListPage() {
<TableCell className="text-xs text-gray-500">
{content.embedding && content.embedding.length > 0 ? (
<div className="max-w-xs">
<span className="font-mono text-xs bg-gray-100 dark:bg-gray-800 rounded px-1 py-0.5" title={`${content.embedding.length}D vector: [${content.embedding.slice(0, 3).map(v => v.toFixed(3)).join(', ')}...]`}>
[{content.embedding.slice(0, 3).map(v => v.toFixed(3)).join(', ')}...] ({content.embedding.length}D)
<span
className="font-mono text-xs bg-gray-100 dark:bg-gray-800 rounded px-1 py-0.5"
title={`${content.embedding.length}D vector: [${content.embedding
.slice(0, 3)
.map(v => v.toFixed(3))
.join(", ")}...]`}
>
[
{content.embedding
.slice(0, 3)
.map(v => v.toFixed(3))
.join(", ")}
...] ({content.embedding.length}D)
</span>
</div>
) : (
<span className="text-gray-400 dark:text-gray-500 italic">No embedding</span>
<span className="text-gray-400 dark:text-gray-500 italic">
No embedding
</span>
)}
</TableCell>
<TableCell className="text-xs text-gray-500">
@ -211,7 +251,9 @@ export default function ContentsListPage() {
: `${content.metadata.content_length || 0} chars`}
</TableCell>
<TableCell className="text-xs">
{new Date(content.created_timestamp * 1000).toLocaleString()}
{new Date(
content.created_timestamp * 1000
).toLocaleString()}
</TableCell>
<TableCell>
<div className="flex gap-1">

Some files were not shown because too many files have changed in this diff Show more