Merge branch 'main' into register_custom_model

This commit is contained in:
Rashmi Pawar 2025-04-16 14:35:51 +05:30 committed by GitHub
commit afb792b9c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
69 changed files with 8875 additions and 890 deletions

View file

@ -34,22 +34,20 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install uv
uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
with:
python-version: "3.10"
- name: Install Ollama
- name: Install and start Ollama
run: |
# the ollama installer also starts the ollama service
curl -fsSL https://ollama.com/install.sh | sh
- name: Pull Ollama image
run: |
# TODO: cache the model. OLLAMA_MODELS defaults to ~ollama/.ollama/models.
ollama pull llama3.2:3b-instruct-fp16
- name: Start Ollama in background
run: |
nohup ollama run llama3.2:3b-instruct-fp16 > ollama.log 2>&1 &
- name: Set Up Environment and Install Dependencies
run: |
uv sync --extra dev --extra test
@ -61,21 +59,6 @@ jobs:
uv pip install -e .
llama stack build --template ollama --image-type venv
- name: Wait for Ollama to start
run: |
echo "Waiting for Ollama..."
for i in {1..30}; do
if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
echo "Ollama is running!"
exit 0
fi
sleep 1
done
echo "Ollama failed to start"
ollama ps
ollama.log
exit 1
- name: Start Llama Stack server in background
if: matrix.client-type == 'http'
env:
@ -99,6 +82,17 @@ jobs:
cat server.log
exit 1
- name: Verify Ollama status is OK
if: matrix.client-type == 'http'
run: |
echo "Verifying Ollama status..."
ollama_status=$(curl -s -L http://127.0.0.1:8321/v1/providers/ollama|jq --raw-output .health.status)
echo "Ollama status: $ollama_status"
if [ "$ollama_status" != "OK" ]; then
echo "Ollama health check failed"
exit 1
fi
- name: Run Integration Tests
env:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"

View file

@ -31,3 +31,12 @@ jobs:
- name: Verify if there are any diff files after pre-commit
run: |
git diff --exit-code || (echo "There are uncommitted changes, run pre-commit locally and commit again" && exit 1)
- name: Verify if there are any new files after pre-commit
run: |
unstaged_files=$(git ls-files --others --exclude-standard)
if [ -n "$unstaged_files" ]; then
echo "There are uncommitted new files, run pre-commit locally and commit again"
echo "$unstaged_files"
exit 1
fi

View file

@ -56,7 +56,7 @@ jobs:
python-version: '3.10'
- name: Install uv
uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
with:
python-version: "3.10"
@ -81,3 +81,29 @@ jobs:
run: |
source test/bin/activate
uv pip list
build-single-provider:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
python-version: "3.10"
- name: Install LlamaStack
run: |
uv venv
source .venv/bin/activate
uv pip install -e .
- name: Build a single provider
run: |
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --image-type venv --image-name test --providers inference=remote::ollama

View file

@ -38,7 +38,7 @@ jobs:
with:
python-version: ${{ matrix.python }}
- uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
- uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
with:
python-version: ${{ matrix.python }}
enable-cache: false

View file

@ -41,7 +41,7 @@ jobs:
python-version: '3.11'
- name: Install the latest version of uv
uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
- name: Sync with uv
run: uv sync --extra docs

View file

@ -9,15 +9,16 @@
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
### ✨🎉 Llama 4 Support 🎉✨
We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.
You can now run Llama 4 models on Llama Stack.
<details>
<summary>👋 Click here to see how to run Llama 4 models on Llama Stack </summary>
\
*Note you need 8xH100 GPU-host to run these models*
```bash
pip install -U llama_stack
@ -67,6 +68,9 @@ print(f"Assistant> {response.completion_message.content}")
As more providers start supporting Llama 4, you can use them in Llama Stack as well. We are adding to the list. Stay tuned!
</details>
### Overview
Llama Stack standardizes the core building blocks that simplify AI application development. It codifies best practices across the Llama ecosystem. More specifically, it provides

View file

@ -3096,11 +3096,18 @@
"post": {
"responses": {
"200": {
"description": "OK",
"description": "Response from an OpenAI-compatible chat completion request. **OR** Chunk from a streaming response to an OpenAI-compatible chat completion request.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/OpenAIChatCompletion"
"oneOf": [
{
"$ref": "#/components/schemas/OpenAIChatCompletion"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionChunk"
}
]
}
}
}
@ -7889,7 +7896,13 @@
"type": "object",
"properties": {
"status": {
"type": "string"
"type": "string",
"enum": [
"OK",
"Error",
"Not Implemented"
],
"title": "HealthStatus"
}
},
"additionalProperties": false,
@ -8084,6 +8097,31 @@
}
]
}
},
"health": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
@ -8091,7 +8129,8 @@
"api",
"provider_id",
"provider_type",
"config"
"config",
"health"
],
"title": "ProviderInfo"
},
@ -8825,7 +8864,17 @@
"description": "Must be \"assistant\" to identify this as the model's response"
},
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"oneOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
}
}
],
"description": "The content of the model's response"
},
"name": {
@ -8835,9 +8884,9 @@
"tool_calls": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolCall"
"$ref": "#/components/schemas/OpenAIChatCompletionToolCall"
},
"description": "List of tool calls. Each tool call is a ToolCall object."
"description": "List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object."
}
},
"additionalProperties": false,
@ -8848,6 +8897,98 @@
"title": "OpenAIAssistantMessageParam",
"description": "A message containing the model's (assistant) response in an OpenAI-compatible chat completion request."
},
"OpenAIChatCompletionContentPartImageParam": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "image_url",
"default": "image_url"
},
"image_url": {
"$ref": "#/components/schemas/OpenAIImageURL"
}
},
"additionalProperties": false,
"required": [
"type",
"image_url"
],
"title": "OpenAIChatCompletionContentPartImageParam"
},
"OpenAIChatCompletionContentPartParam": {
"oneOf": [
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"text": "#/components/schemas/OpenAIChatCompletionContentPartTextParam",
"image_url": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
}
}
},
"OpenAIChatCompletionContentPartTextParam": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "text",
"default": "text"
},
"text": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"type",
"text"
],
"title": "OpenAIChatCompletionContentPartTextParam"
},
"OpenAIChatCompletionToolCall": {
"type": "object",
"properties": {
"index": {
"type": "integer"
},
"id": {
"type": "string"
},
"type": {
"type": "string",
"const": "function",
"default": "function"
},
"function": {
"$ref": "#/components/schemas/OpenAIChatCompletionToolCallFunction"
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "OpenAIChatCompletionToolCall"
},
"OpenAIChatCompletionToolCallFunction": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"arguments": {
"type": "string"
}
},
"additionalProperties": false,
"title": "OpenAIChatCompletionToolCallFunction"
},
"OpenAIDeveloperMessageParam": {
"type": "object",
"properties": {
@ -8858,7 +8999,17 @@
"description": "Must be \"developer\" to identify this as a developer message"
},
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"oneOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
}
}
],
"description": "The content of the developer message"
},
"name": {
@ -8874,6 +9025,66 @@
"title": "OpenAIDeveloperMessageParam",
"description": "A message from the developer in an OpenAI-compatible chat completion request."
},
"OpenAIImageURL": {
"type": "object",
"properties": {
"url": {
"type": "string"
},
"detail": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"url"
],
"title": "OpenAIImageURL"
},
"OpenAIJSONSchema": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"description": {
"type": "string"
},
"strict": {
"type": "boolean"
},
"schema": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"name"
],
"title": "OpenAIJSONSchema"
},
"OpenAIMessageParam": {
"oneOf": [
{
@ -8903,6 +9114,76 @@
}
}
},
"OpenAIResponseFormatJSONObject": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "json_object",
"default": "json_object"
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "OpenAIResponseFormatJSONObject"
},
"OpenAIResponseFormatJSONSchema": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "json_schema",
"default": "json_schema"
},
"json_schema": {
"$ref": "#/components/schemas/OpenAIJSONSchema"
}
},
"additionalProperties": false,
"required": [
"type",
"json_schema"
],
"title": "OpenAIResponseFormatJSONSchema"
},
"OpenAIResponseFormatParam": {
"oneOf": [
{
"$ref": "#/components/schemas/OpenAIResponseFormatText"
},
{
"$ref": "#/components/schemas/OpenAIResponseFormatJSONSchema"
},
{
"$ref": "#/components/schemas/OpenAIResponseFormatJSONObject"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"text": "#/components/schemas/OpenAIResponseFormatText",
"json_schema": "#/components/schemas/OpenAIResponseFormatJSONSchema",
"json_object": "#/components/schemas/OpenAIResponseFormatJSONObject"
}
}
},
"OpenAIResponseFormatText": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "text",
"default": "text"
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "OpenAIResponseFormatText"
},
"OpenAISystemMessageParam": {
"type": "object",
"properties": {
@ -8913,7 +9194,17 @@
"description": "Must be \"system\" to identify this as a system message"
},
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"oneOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
}
}
],
"description": "The content of the \"system prompt\". If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other system messages (for example, for formatting tool definitions)."
},
"name": {
@ -8943,7 +9234,17 @@
"description": "Unique identifier for the tool call this response is for"
},
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"oneOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
}
}
],
"description": "The response content from the tool"
}
},
@ -8966,7 +9267,17 @@
"description": "Must be \"user\" to identify this as a user message"
},
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"oneOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
}
}
],
"description": "The content of the message, which can include text and other media"
},
"name": {
@ -9094,10 +9405,7 @@
"description": "(Optional) The penalty for repeated tokens"
},
"response_format": {
"type": "object",
"additionalProperties": {
"type": "string"
},
"$ref": "#/components/schemas/OpenAIResponseFormatParam",
"description": "(Optional) The response format to use"
},
"seed": {
@ -9274,6 +9582,46 @@
"title": "OpenAIChatCompletion",
"description": "Response from an OpenAI-compatible chat completion request."
},
"OpenAIChatCompletionChunk": {
"type": "object",
"properties": {
"id": {
"type": "string",
"description": "The ID of the chat completion"
},
"choices": {
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIChunkChoice"
},
"description": "List of choices"
},
"object": {
"type": "string",
"const": "chat.completion.chunk",
"default": "chat.completion.chunk",
"description": "The object type, which will be \"chat.completion.chunk\""
},
"created": {
"type": "integer",
"description": "The Unix timestamp in seconds when the chat completion was created"
},
"model": {
"type": "string",
"description": "The model that was used to generate the chat completion"
}
},
"additionalProperties": false,
"required": [
"id",
"choices",
"object",
"created",
"model"
],
"title": "OpenAIChatCompletionChunk",
"description": "Chunk from a streaming response to an OpenAI-compatible chat completion request."
},
"OpenAIChoice": {
"type": "object",
"properties": {
@ -9286,10 +9634,12 @@
"description": "The reason the model stopped generating"
},
"index": {
"type": "integer"
"type": "integer",
"description": "The index of the choice"
},
"logprobs": {
"$ref": "#/components/schemas/OpenAIChoiceLogprobs"
"$ref": "#/components/schemas/OpenAIChoiceLogprobs",
"description": "(Optional) The log probabilities for the tokens in the message"
}
},
"additionalProperties": false,
@ -9301,6 +9651,33 @@
"title": "OpenAIChoice",
"description": "A choice from an OpenAI-compatible chat completion response."
},
"OpenAIChoiceDelta": {
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "(Optional) The content of the delta"
},
"refusal": {
"type": "string",
"description": "(Optional) The refusal of the delta"
},
"role": {
"type": "string",
"description": "(Optional) The role of the delta"
},
"tool_calls": {
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIChatCompletionToolCall"
},
"description": "(Optional) The tool calls of the delta"
}
},
"additionalProperties": false,
"title": "OpenAIChoiceDelta",
"description": "A delta from an OpenAI-compatible chat completion streaming response."
},
"OpenAIChoiceLogprobs": {
"type": "object",
"properties": {
@ -9308,19 +9685,50 @@
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAITokenLogProb"
}
},
"description": "(Optional) The log probabilities for the tokens in the message"
},
"refusal": {
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAITokenLogProb"
}
},
"description": "(Optional) The log probabilities for the tokens in the message"
}
},
"additionalProperties": false,
"title": "OpenAIChoiceLogprobs",
"description": "The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response."
},
"OpenAIChunkChoice": {
"type": "object",
"properties": {
"delta": {
"$ref": "#/components/schemas/OpenAIChoiceDelta",
"description": "The delta from the chunk"
},
"finish_reason": {
"type": "string",
"description": "The reason the model stopped generating"
},
"index": {
"type": "integer",
"description": "The index of the choice"
},
"logprobs": {
"$ref": "#/components/schemas/OpenAIChoiceLogprobs",
"description": "(Optional) The log probabilities for the tokens in the message"
}
},
"additionalProperties": false,
"required": [
"delta",
"finish_reason",
"index"
],
"title": "OpenAIChunkChoice",
"description": "A chunk choice from an OpenAI-compatible chat completion streaming response."
},
"OpenAITokenLogProb": {
"type": "object",
"properties": {

View file

@ -2135,11 +2135,15 @@ paths:
post:
responses:
'200':
description: OK
description: >-
Response from an OpenAI-compatible chat completion request. **OR** Chunk
from a streaming response to an OpenAI-compatible chat completion request.
content:
application/json:
schema:
$ref: '#/components/schemas/OpenAIChatCompletion'
oneOf:
- $ref: '#/components/schemas/OpenAIChatCompletion'
- $ref: '#/components/schemas/OpenAIChatCompletionChunk'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -5463,6 +5467,11 @@ components:
properties:
status:
type: string
enum:
- OK
- Error
- Not Implemented
title: HealthStatus
additionalProperties: false
required:
- status
@ -5574,12 +5583,23 @@ components:
- type: string
- type: array
- type: object
health:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- api
- provider_id
- provider_type
- config
- health
title: ProviderInfo
InvokeToolRequest:
type: object
@ -6057,7 +6077,11 @@ components:
description: >-
Must be "assistant" to identify this as the model's response
content:
$ref: '#/components/schemas/InterleavedContent'
oneOf:
- type: string
- type: array
items:
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
description: The content of the model's response
name:
type: string
@ -6066,9 +6090,10 @@ components:
tool_calls:
type: array
items:
$ref: '#/components/schemas/ToolCall'
$ref: '#/components/schemas/OpenAIChatCompletionToolCall'
description: >-
List of tool calls. Each tool call is a ToolCall object.
List of tool calls. Each tool call is an OpenAIChatCompletionToolCall
object.
additionalProperties: false
required:
- role
@ -6077,6 +6102,70 @@ components:
description: >-
A message containing the model's (assistant) response in an OpenAI-compatible
chat completion request.
"OpenAIChatCompletionContentPartImageParam":
type: object
properties:
type:
type: string
const: image_url
default: image_url
image_url:
$ref: '#/components/schemas/OpenAIImageURL'
additionalProperties: false
required:
- type
- image_url
title: >-
OpenAIChatCompletionContentPartImageParam
OpenAIChatCompletionContentPartParam:
oneOf:
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
discriminator:
propertyName: type
mapping:
text: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
image_url: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
OpenAIChatCompletionContentPartTextParam:
type: object
properties:
type:
type: string
const: text
default: text
text:
type: string
additionalProperties: false
required:
- type
- text
title: OpenAIChatCompletionContentPartTextParam
OpenAIChatCompletionToolCall:
type: object
properties:
index:
type: integer
id:
type: string
type:
type: string
const: function
default: function
function:
$ref: '#/components/schemas/OpenAIChatCompletionToolCallFunction'
additionalProperties: false
required:
- type
title: OpenAIChatCompletionToolCall
OpenAIChatCompletionToolCallFunction:
type: object
properties:
name:
type: string
arguments:
type: string
additionalProperties: false
title: OpenAIChatCompletionToolCallFunction
OpenAIDeveloperMessageParam:
type: object
properties:
@ -6087,7 +6176,11 @@ components:
description: >-
Must be "developer" to identify this as a developer message
content:
$ref: '#/components/schemas/InterleavedContent'
oneOf:
- type: string
- type: array
items:
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
description: The content of the developer message
name:
type: string
@ -6100,6 +6193,40 @@ components:
title: OpenAIDeveloperMessageParam
description: >-
A message from the developer in an OpenAI-compatible chat completion request.
OpenAIImageURL:
type: object
properties:
url:
type: string
detail:
type: string
additionalProperties: false
required:
- url
title: OpenAIImageURL
OpenAIJSONSchema:
type: object
properties:
name:
type: string
description:
type: string
strict:
type: boolean
schema:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- name
title: OpenAIJSONSchema
OpenAIMessageParam:
oneOf:
- $ref: '#/components/schemas/OpenAIUserMessageParam'
@ -6115,6 +6242,53 @@ components:
assistant: '#/components/schemas/OpenAIAssistantMessageParam'
tool: '#/components/schemas/OpenAIToolMessageParam'
developer: '#/components/schemas/OpenAIDeveloperMessageParam'
OpenAIResponseFormatJSONObject:
type: object
properties:
type:
type: string
const: json_object
default: json_object
additionalProperties: false
required:
- type
title: OpenAIResponseFormatJSONObject
OpenAIResponseFormatJSONSchema:
type: object
properties:
type:
type: string
const: json_schema
default: json_schema
json_schema:
$ref: '#/components/schemas/OpenAIJSONSchema'
additionalProperties: false
required:
- type
- json_schema
title: OpenAIResponseFormatJSONSchema
OpenAIResponseFormatParam:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseFormatText'
- $ref: '#/components/schemas/OpenAIResponseFormatJSONSchema'
- $ref: '#/components/schemas/OpenAIResponseFormatJSONObject'
discriminator:
propertyName: type
mapping:
text: '#/components/schemas/OpenAIResponseFormatText'
json_schema: '#/components/schemas/OpenAIResponseFormatJSONSchema'
json_object: '#/components/schemas/OpenAIResponseFormatJSONObject'
OpenAIResponseFormatText:
type: object
properties:
type:
type: string
const: text
default: text
additionalProperties: false
required:
- type
title: OpenAIResponseFormatText
OpenAISystemMessageParam:
type: object
properties:
@ -6125,7 +6299,11 @@ components:
description: >-
Must be "system" to identify this as a system message
content:
$ref: '#/components/schemas/InterleavedContent'
oneOf:
- type: string
- type: array
items:
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
description: >-
The content of the "system prompt". If multiple system messages are provided,
they are concatenated. The underlying Llama Stack code may also add other
@ -6155,7 +6333,11 @@ components:
description: >-
Unique identifier for the tool call this response is for
content:
$ref: '#/components/schemas/InterleavedContent'
oneOf:
- type: string
- type: array
items:
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
description: The response content from the tool
additionalProperties: false
required:
@ -6176,7 +6358,11 @@ components:
description: >-
Must be "user" to identify this as a user message
content:
$ref: '#/components/schemas/InterleavedContent'
oneOf:
- type: string
- type: array
items:
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
description: >-
The content of the message, which can include text and other media
name:
@ -6262,9 +6448,7 @@ components:
description: >-
(Optional) The penalty for repeated tokens
response_format:
type: object
additionalProperties:
type: string
$ref: '#/components/schemas/OpenAIResponseFormatParam'
description: (Optional) The response format to use
seed:
type: integer
@ -6370,6 +6554,41 @@ components:
title: OpenAIChatCompletion
description: >-
Response from an OpenAI-compatible chat completion request.
OpenAIChatCompletionChunk:
type: object
properties:
id:
type: string
description: The ID of the chat completion
choices:
type: array
items:
$ref: '#/components/schemas/OpenAIChunkChoice'
description: List of choices
object:
type: string
const: chat.completion.chunk
default: chat.completion.chunk
description: >-
The object type, which will be "chat.completion.chunk"
created:
type: integer
description: >-
The Unix timestamp in seconds when the chat completion was created
model:
type: string
description: >-
The model that was used to generate the chat completion
additionalProperties: false
required:
- id
- choices
- object
- created
- model
title: OpenAIChatCompletionChunk
description: >-
Chunk from a streaming response to an OpenAI-compatible chat completion request.
OpenAIChoice:
type: object
properties:
@ -6381,8 +6600,11 @@ components:
description: The reason the model stopped generating
index:
type: integer
description: The index of the choice
logprobs:
$ref: '#/components/schemas/OpenAIChoiceLogprobs'
description: >-
(Optional) The log probabilities for the tokens in the message
additionalProperties: false
required:
- message
@ -6391,6 +6613,27 @@ components:
title: OpenAIChoice
description: >-
A choice from an OpenAI-compatible chat completion response.
OpenAIChoiceDelta:
type: object
properties:
content:
type: string
description: (Optional) The content of the delta
refusal:
type: string
description: (Optional) The refusal of the delta
role:
type: string
description: (Optional) The role of the delta
tool_calls:
type: array
items:
$ref: '#/components/schemas/OpenAIChatCompletionToolCall'
description: (Optional) The tool calls of the delta
additionalProperties: false
title: OpenAIChoiceDelta
description: >-
A delta from an OpenAI-compatible chat completion streaming response.
OpenAIChoiceLogprobs:
type: object
properties:
@ -6398,15 +6641,43 @@ components:
type: array
items:
$ref: '#/components/schemas/OpenAITokenLogProb'
description: >-
(Optional) The log probabilities for the tokens in the message
refusal:
type: array
items:
$ref: '#/components/schemas/OpenAITokenLogProb'
description: >-
(Optional) The log probabilities for the tokens in the message
additionalProperties: false
title: OpenAIChoiceLogprobs
description: >-
The log probabilities for the tokens in the message from an OpenAI-compatible
chat completion response.
OpenAIChunkChoice:
type: object
properties:
delta:
$ref: '#/components/schemas/OpenAIChoiceDelta'
description: The delta from the chunk
finish_reason:
type: string
description: The reason the model stopped generating
index:
type: integer
description: The index of the choice
logprobs:
$ref: '#/components/schemas/OpenAIChoiceLogprobs'
description: >-
(Optional) The log probabilities for the tokens in the message
additionalProperties: false
required:
- delta
- finish_reason
- index
title: OpenAIChunkChoice
description: >-
A chunk choice from an OpenAI-compatible chat completion streaming response.
OpenAITokenLogProb:
type: object
properties:

View file

@ -24,7 +24,7 @@ The key files in the app are `ExampleLlamaStackLocalInference.kt`, `ExampleLlama
Add the following dependency in your `build.gradle.kts` file:
```
dependencies {
implementation("com.llama.llamastack:llama-stack-client-kotlin:0.1.4.2")
implementation("com.llama.llamastack:llama-stack-client-kotlin:0.2.2")
}
```
This will download jar files in your gradle cache in a directory like `~/.gradle/caches/modules-2/files-2.1/com.llama.llamastack/`
@ -37,11 +37,7 @@ For local inferencing, it is required to include the ExecuTorch library into you
Include the ExecuTorch library by:
1. Download the `download-prebuilt-et-lib.sh` script file from the [llama-stack-client-kotlin-client-local](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/llama-stack-client-kotlin-client-local/download-prebuilt-et-lib.sh) directory to your local machine.
2. Move the script to the top level of your Android app where the app directory resides:
<p align="center">
<img src="https://github.com/meta-llama/llama-stack-client-kotlin/blob/latest-release/doc/img/example_android_app_directory.png" style="width:300px">
</p>
2. Move the script to the top level of your Android app where the `app` directory resides.
3. Run `sh download-prebuilt-et-lib.sh` to create an `app/libs` directory and download the `executorch.aar` in that path. This generates an ExecuTorch library for the XNNPACK delegate.
4. Add the `executorch.aar` dependency in your `build.gradle.kts` file:
```
@ -52,6 +48,8 @@ dependencies {
}
```
See other dependencies for the local RAG in Android app [README](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app#quick-start).
## Llama Stack APIs in Your Android App
Breaking down the demo app, this section will show the core pieces that are used to initialize and run inference with Llama Stack using the Kotlin library.
@ -60,7 +58,7 @@ Start a Llama Stack server on localhost. Here is an example of how you can do th
```
conda create -n stack-fireworks python=3.10
conda activate stack-fireworks
pip install --no-cache llama-stack==0.1.4
pip install --no-cache llama-stack==0.2.2
llama stack build --template fireworks --image-type conda
export FIREWORKS_API_KEY=<SOME_KEY>
llama stack run fireworks --port 5050

View file

@ -43,7 +43,9 @@ The following models are available by default:
- `groq/llama-3.3-70b-versatile (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `groq/llama-3.2-3b-preview (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `groq/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
- `groq/meta-llama/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
- `groq/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
- `groq/meta-llama/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
### Prerequisite: API Keys

View file

@ -41,7 +41,7 @@ The following environment variables can be configured:
## Setting up vLLM server
In the following sections, we'll use either AMD and NVIDIA GPUs to serve as hardware accelerators for the vLLM
In the following sections, we'll use AMD, NVIDIA or Intel GPUs to serve as hardware accelerators for the vLLM
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
that we only use GPUs here for demonstration purposes.
@ -162,6 +162,55 @@ docker run \
--port $SAFETY_PORT
```
### Setting up vLLM server on Intel GPU
Refer to [vLLM Documentation for XPU](https://docs.vllm.ai/en/v0.8.2/getting_started/installation/gpu.html?device=xpu) to get a vLLM endpoint. In addition to vLLM side setup which guides towards installing vLLM from sources orself-building vLLM Docker container, Intel provides prebuilt vLLM container to use on systems with Intel GPUs supported by PyTorch XPU backend:
- [intel/vllm](https://hub.docker.com/r/intel/vllm)
Here is a sample script to start a vLLM server locally via Docker using Intel provided container:
```bash
export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct
export ZE_AFFINITY_MASK=0
docker run \
--pull always \
--device /dev/dri \
-v /dev/dri/by-path:/dev/dri/by-path \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
--env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
-p $INFERENCE_PORT:$INFERENCE_PORT \
--ipc=host \
intel/vllm:xpu \
--gpu-memory-utilization 0.7 \
--model $INFERENCE_MODEL \
--port $INFERENCE_PORT
```
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
```bash
export SAFETY_PORT=8081
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export ZE_AFFINITY_MASK=1
docker run \
--pull always \
--device /dev/dri \
-v /dev/dri/by-path:/dev/dri/by-path \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
--env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
-p $SAFETY_PORT:$SAFETY_PORT \
--ipc=host \
intel/vllm:xpu \
--gpu-memory-utilization 0.7 \
--model $SAFETY_MODEL \
--port $SAFETY_PORT
```
## Running Llama Stack
Now you are ready to run Llama Stack with vLLM as the inference provider. You can do this via Conda (build code) or Docker which has a pre-built image.

View file

@ -18,7 +18,7 @@ from typing import (
)
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from typing_extensions import Annotated, TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.models import Model
@ -442,6 +442,37 @@ class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
@json_schema_type
class OpenAIChatCompletionContentPartTextParam(BaseModel):
type: Literal["text"] = "text"
text: str
@json_schema_type
class OpenAIImageURL(BaseModel):
url: str
detail: Optional[str] = None
@json_schema_type
class OpenAIChatCompletionContentPartImageParam(BaseModel):
type: Literal["image_url"] = "image_url"
image_url: OpenAIImageURL
OpenAIChatCompletionContentPartParam = Annotated[
Union[
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionContentPartImageParam,
],
Field(discriminator="type"),
]
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
@json_schema_type
class OpenAIUserMessageParam(BaseModel):
"""A message from the user in an OpenAI-compatible chat completion request.
@ -452,7 +483,7 @@ class OpenAIUserMessageParam(BaseModel):
"""
role: Literal["user"] = "user"
content: InterleavedContent
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
@ -466,10 +497,24 @@ class OpenAISystemMessageParam(BaseModel):
"""
role: Literal["system"] = "system"
content: InterleavedContent
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
@json_schema_type
class OpenAIChatCompletionToolCallFunction(BaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
@json_schema_type
class OpenAIChatCompletionToolCall(BaseModel):
index: Optional[int] = None
id: Optional[str] = None
type: Literal["function"] = "function"
function: Optional[OpenAIChatCompletionToolCallFunction] = None
@json_schema_type
class OpenAIAssistantMessageParam(BaseModel):
"""A message containing the model's (assistant) response in an OpenAI-compatible chat completion request.
@ -477,13 +522,13 @@ class OpenAIAssistantMessageParam(BaseModel):
:param role: Must be "assistant" to identify this as the model's response
:param content: The content of the model's response
:param name: (Optional) The name of the assistant message participant.
:param tool_calls: List of tool calls. Each tool call is a ToolCall object.
:param tool_calls: List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object.
"""
role: Literal["assistant"] = "assistant"
content: InterleavedContent
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = Field(default_factory=list)
@json_schema_type
@ -497,7 +542,7 @@ class OpenAIToolMessageParam(BaseModel):
role: Literal["tool"] = "tool"
tool_call_id: str
content: InterleavedContent
content: OpenAIChatCompletionMessageContent
@json_schema_type
@ -510,7 +555,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
"""
role: Literal["developer"] = "developer"
content: InterleavedContent
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
@ -527,6 +572,46 @@ OpenAIMessageParam = Annotated[
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
@json_schema_type
class OpenAIResponseFormatText(BaseModel):
type: Literal["text"] = "text"
@json_schema_type
class OpenAIJSONSchema(TypedDict, total=False):
name: str
description: Optional[str] = None
strict: Optional[bool] = None
# Pydantic BaseModel cannot be used with a schema param, since it already
# has one. And, we don't want to alias here because then have to handle
# that alias when converting to OpenAI params. So, to support schema,
# we use a TypedDict.
schema: Optional[Dict[str, Any]] = None
@json_schema_type
class OpenAIResponseFormatJSONSchema(BaseModel):
type: Literal["json_schema"] = "json_schema"
json_schema: OpenAIJSONSchema
@json_schema_type
class OpenAIResponseFormatJSONObject(BaseModel):
type: Literal["json_object"] = "json_object"
OpenAIResponseFormatParam = Annotated[
Union[
OpenAIResponseFormatText,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatJSONObject,
],
Field(discriminator="type"),
]
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@json_schema_type
class OpenAITopLogProb(BaseModel):
"""The top log probability for a token from an OpenAI-compatible chat completion response.
@ -561,22 +646,54 @@ class OpenAITokenLogProb(BaseModel):
class OpenAIChoiceLogprobs(BaseModel):
"""The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response.
:content: (Optional) The log probabilities for the tokens in the message
:refusal: (Optional) The log probabilities for the tokens in the message
:param content: (Optional) The log probabilities for the tokens in the message
:param refusal: (Optional) The log probabilities for the tokens in the message
"""
content: Optional[List[OpenAITokenLogProb]] = None
refusal: Optional[List[OpenAITokenLogProb]] = None
@json_schema_type
class OpenAIChoiceDelta(BaseModel):
"""A delta from an OpenAI-compatible chat completion streaming response.
:param content: (Optional) The content of the delta
:param refusal: (Optional) The refusal of the delta
:param role: (Optional) The role of the delta
:param tool_calls: (Optional) The tool calls of the delta
"""
content: Optional[str] = None
refusal: Optional[str] = None
role: Optional[str] = None
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
@json_schema_type
class OpenAIChunkChoice(BaseModel):
"""A chunk choice from an OpenAI-compatible chat completion streaming response.
:param delta: The delta from the chunk
:param finish_reason: The reason the model stopped generating
:param index: The index of the choice
:param logprobs: (Optional) The log probabilities for the tokens in the message
"""
delta: OpenAIChoiceDelta
finish_reason: str
index: int
logprobs: Optional[OpenAIChoiceLogprobs] = None
@json_schema_type
class OpenAIChoice(BaseModel):
"""A choice from an OpenAI-compatible chat completion response.
:param message: The message from the model
:param finish_reason: The reason the model stopped generating
:index: The index of the choice
:logprobs: (Optional) The log probabilities for the tokens in the message
:param index: The index of the choice
:param logprobs: (Optional) The log probabilities for the tokens in the message
"""
message: OpenAIMessageParam
@ -603,6 +720,24 @@ class OpenAIChatCompletion(BaseModel):
model: str
@json_schema_type
class OpenAIChatCompletionChunk(BaseModel):
"""Chunk from a streaming response to an OpenAI-compatible chat completion request.
:param id: The ID of the chat completion
:param choices: List of choices
:param object: The object type, which will be "chat.completion.chunk"
:param created: The Unix timestamp in seconds when the chat completion was created
:param model: The model that was used to generate the chat completion
"""
id: str
choices: List[OpenAIChunkChoice]
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int
model: str
@json_schema_type
class OpenAICompletionLogprobs(BaseModel):
"""The log probabilities for the tokens in the message from an OpenAI-compatible completion response.
@ -872,7 +1007,7 @@ class Inference(Protocol):
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -883,7 +1018,7 @@ class Inference(Protocol):
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.

View file

@ -8,6 +8,7 @@ from typing import List, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.schema_utils import json_schema_type, webmethod
@ -20,8 +21,7 @@ class RouteInfo(BaseModel):
@json_schema_type
class HealthInfo(BaseModel):
status: str
# TODO: add a provider level status
status: HealthStatus
@json_schema_type

View file

@ -8,6 +8,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.providers.datatypes import HealthResponse
from llama_stack.schema_utils import json_schema_type, webmethod
@ -17,6 +18,7 @@ class ProviderInfo(BaseModel):
provider_id: str
provider_type: str
config: Dict[str, Any]
health: HealthResponse
class ListProvidersResponse(BaseModel):

View file

@ -89,6 +89,43 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
color="red",
)
sys.exit(1)
elif args.providers:
providers = dict()
for api_provider in args.providers.split(","):
if "=" not in api_provider:
cprint(
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
color="red",
)
sys.exit(1)
api, provider = api_provider.split("=")
providers_for_api = get_provider_registry().get(Api(api), None)
if providers_for_api is None:
cprint(
f"{api} is not a valid API.",
color="red",
)
sys.exit(1)
if provider in providers_for_api:
providers.setdefault(api, []).append(provider)
else:
cprint(
f"{provider} is not a valid provider for the {api} API.",
color="red",
)
sys.exit(1)
distribution_spec = DistributionSpec(
providers=providers,
description=",".join(args.providers),
)
if not args.image_type:
cprint(
f"Please specify a image-type (container | conda | venv) for {args.template}",
color="red",
)
sys.exit(1)
build_config = BuildConfig(image_type=args.image_type, distribution_spec=distribution_spec)
elif not args.config and not args.template:
name = prompt(
"> Enter a name for your Llama Stack (e.g. my-local-stack): ",

View file

@ -75,6 +75,12 @@ the build. If not specified, currently active environment will be used if found.
default=False,
help="Run the stack after building using the same image type, name, and other applicable arguments",
)
self.parser.add_argument(
"--providers",
type=str,
default=None,
help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
# always keep implementation completely silo-ed away from CLI so CLI

View file

@ -17,6 +17,7 @@ from llama_stack.apis.inspect import (
)
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.providers.datatypes import HealthStatus
class DistributionInspectConfig(BaseModel):
@ -58,7 +59,7 @@ class DistributionInspectImpl(Inspect):
return ListRoutesResponse(data=ret)
async def health(self) -> HealthInfo:
return HealthInfo(status="OK")
return HealthInfo(status=HealthStatus.OK)
async def version(self) -> VersionInfo:
return VersionInfo(version=version("llama-stack"))

View file

@ -43,9 +43,9 @@ from llama_stack.distribution.server.endpoints import (
from llama_stack.distribution.stack import (
construct_stack,
get_stack_run_config_from_template,
redact_sensitive_fields,
replace_env_vars,
)
from llama_stack.distribution.utils.config import redact_sensitive_fields
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.distribution.utils.exec import in_notebook
from llama_stack.providers.utils.telemetry.tracing import (

View file

@ -4,14 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
from .datatypes import StackRunConfig
from .stack import redact_sensitive_fields
from .utils.config import redact_sensitive_fields
logger = get_logger(name=__name__, category="core")
@ -41,19 +44,24 @@ class ProviderImpl(Providers):
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
providers_health = await self.get_providers_health()
ret = []
for api, providers in safe_config.providers.items():
ret.extend(
[
for p in providers:
ret.append(
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
config=p.config,
health=providers_health.get(api, {}).get(
p.provider_id,
HealthResponse(
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
),
),
)
for p in providers
]
)
)
return ListProvidersResponse(data=ret)
@ -64,3 +72,57 @@ class ProviderImpl(Providers):
return p
raise ValueError(f"Provider {provider_id} not found")
async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]:
"""Get health status for all providers.
Returns:
Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
Each API maps to a dictionary of provider IDs to their health responses.
"""
providers_health: Dict[str, Dict[str, HealthResponse]] = {}
timeout = 1.0
async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
# Skip special implementations (inspect/providers) that don't have provider specs
if not hasattr(impl, "__provider_spec__"):
return None
api_name = impl.__provider_spec__.api.name
if not hasattr(impl, "health"):
return (
api_name,
HealthResponse(
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
),
)
try:
health = await asyncio.wait_for(impl.health(), timeout=timeout)
return api_name, health
except asyncio.TimeoutError:
return (
api_name,
HealthResponse(
status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds"
),
)
except Exception as e:
return (
api_name,
HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"),
)
# Create tasks for all providers
tasks = [check_provider_health(impl) for impl in self.deps.values()]
# Wait for all health checks to complete
results = await asyncio.gather(*tasks)
# Organize results by API and provider ID
for result in results:
if result is None: # Skip special implementations
continue
api_name, health_response = result
providers_health[api_name] = health_response
return providers_health

View file

@ -41,7 +41,6 @@ from llama_stack.providers.datatypes import (
Api,
BenchmarksProtocolPrivate,
DatasetsProtocolPrivate,
InlineProviderSpec,
ModelsProtocolPrivate,
ProviderSpec,
RemoteProviderConfig,
@ -230,46 +229,6 @@ def sort_providers_by_deps(
{k: list(v.values()) for k, v in providers_with_specs.items()}
)
# Append built-in "inspect" provider
apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append(
(
"inspect",
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={"run_config": run_config.model_dump()},
spec=InlineProviderSpec(
api=Api.inspect,
provider_type="__builtin__",
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
module="llama_stack.distribution.inspect",
api_dependencies=apis,
deps__=[x.value for x in apis],
),
),
)
)
sorted_providers.append(
(
"providers",
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={"run_config": run_config.model_dump()},
spec=InlineProviderSpec(
api=Api.providers,
provider_type="__builtin__",
config_class="llama_stack.distribution.providers.ProviderImplConfig",
module="llama_stack.distribution.providers",
api_dependencies=apis,
deps__=[x.value for x in apis],
),
),
)
)
logger.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logger.debug(f" {api_str} => {provider.provider_id}")

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
@ -37,7 +38,13 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
@ -60,7 +67,7 @@ from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core")
@ -530,7 +537,7 @@ class InferenceRouter(Inference):
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -541,7 +548,7 @@ class InferenceRouter(Inference):
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
)
@ -580,6 +587,29 @@ class InferenceRouter(Inference):
provider = self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_chat_completion(**params)
async def health(self) -> Dict[str, HealthResponse]:
health_statuses = {}
timeout = 0.5
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
try:
# check if the provider has a health method
if not hasattr(impl, "health"):
continue
health = await asyncio.wait_for(impl.health(), timeout=timeout)
health_statuses[provider_id] = health
except asyncio.TimeoutError:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR,
message=f"Health check timed out after {timeout} seconds",
)
except NotImplementedError:
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
except Exception as e:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
)
return health_statuses
class SafetyRouter(Safety):
def __init__(

View file

@ -38,10 +38,10 @@ from llama_stack.distribution.server.endpoints import (
)
from llama_stack.distribution.stack import (
construct_stack,
redact_sensitive_fields,
replace_env_vars,
validate_env_pair,
)
from llama_stack.distribution.utils.config import redact_sensitive_fields
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
@ -229,15 +229,30 @@ class TracingMiddleware:
def __init__(self, app, impls):
self.app = app
self.impls = impls
# FastAPI built-in paths that should bypass custom routing
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
async def __call__(self, scope, receive, send):
if scope.get("type") == "lifespan":
return await self.app(scope, receive, send)
path = scope.get("path", "")
# Check if the path is a FastAPI built-in path
if path.startswith(self.fastapi_paths):
# Pass through to FastAPI's built-in handlers
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send)
if not hasattr(self, "endpoint_impls"):
self.endpoint_impls = initialize_endpoint_impls(self.impls)
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
try:
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send)
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
@ -388,7 +403,12 @@ def main(args: Optional[argparse.Namespace] = None):
safe_config = redact_sensitive_fields(config.model_dump())
logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan)
app = FastAPI(
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)

View file

@ -35,6 +35,8 @@ from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -119,26 +121,6 @@ class EnvVarError(Exception):
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
"""Redact sensitive information from config before printing."""
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
if isinstance(v, dict):
result[k] = _redact_dict(v)
elif isinstance(v, list):
result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
elif any(pattern in k.lower() for pattern in sensitive_patterns):
result[k] = "********"
else:
result[k] = v
return result
return _redact_dict(data)
def replace_env_vars(config: Any, path: str = "") -> Any:
if isinstance(config, dict):
result = {}
@ -215,6 +197,26 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
) from e
def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None:
"""Add internal implementations (inspect and providers) to the implementations dictionary.
Args:
impls: Dictionary of API implementations
run_config: Stack run configuration
"""
inspect_impl = DistributionInspectImpl(
DistributionInspectConfig(run_config=run_config),
deps=impls,
)
impls[Api.inspect] = inspect_impl
providers_impl = ProviderImpl(
ProviderImplConfig(run_config=run_config),
deps=impls,
)
impls[Api.providers] = providers_impl
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(
@ -222,6 +224,10 @@ async def construct_stack(
) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, run_config)
await register_resources(run_config, impls)
return impls

View file

@ -56,6 +56,17 @@ def tool_chat_page():
st.subheader(f"Active Tools: 🛠 {len(active_tool_list)}")
st.json(active_tool_list)
st.subheader("Chat Configurations")
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=1,
help="The maximum number of tokens to generate",
on_change=reset_agent,
)
@st.cache_resource
def create_agent():
return Agent(
@ -63,9 +74,7 @@ def tool_chat_page():
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={
"strategy": {"type": "greedy"},
},
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
agent = create_agent()

View file

@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
"""Redact sensitive information from config before printing."""
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
def _redact_value(v: Any) -> Any:
if isinstance(v, dict):
return _redact_dict(v)
elif isinstance(v, list):
return [_redact_value(i) for i in v]
return v
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
if any(pattern in k.lower() for pattern in sensitive_patterns):
result[k] = "********"
else:
result[k] = _redact_value(v)
return result
return _redact_dict(data)

View file

@ -204,7 +204,9 @@ class ToolUtils:
return None
elif is_json(message_body):
response = json.loads(message_body)
if ("type" in response and response["type"] == "function") or ("name" in response):
if ("type" in response and response["type"] == "function") or (
"name" in response and "parameters" in response
):
function_name = response["name"]
args = response["parameters"]
return function_name, args

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, List, Optional, Protocol
from urllib.parse import urlparse
@ -201,3 +202,12 @@ def remote_provider_spec(
adapter=adapter,
api_dependencies=api_dependencies or [],
)
class HealthStatus(str, Enum):
OK = "OK"
ERROR = "Error"
NOT_IMPLEMENTED = "Not Implemented"
HealthResponse = dict[str, Any]

View file

@ -59,8 +59,8 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
@ -83,8 +83,8 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_
class MetaReferenceInferenceImpl(
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionToLlamaStackMixin,
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,

View file

@ -25,8 +25,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
)
from .config import SentenceTransformersInferenceConfig
@ -35,8 +35,8 @@ log = logging.getLogger(__name__)
class SentenceTransformersInferenceImpl(
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,

View file

@ -66,10 +66,10 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionUnsupportedMixin,
OpenAICompletionToLlamaStackMixin,
get_stop_reason,
process_chat_completion_stream_response,
)
@ -176,8 +176,8 @@ def _convert_sampling_params(
class VLLMInferenceImpl(
Inference,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
"""

View file

@ -3,13 +3,14 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
@ -25,9 +26,19 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
class TrainingArtifactType(Enum):
CHECKPOINT = "checkpoint"
RESOURCES_STATS = "resources_stats"
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
class TorchtunePostTrainingImpl:
def __init__(
self,
@ -38,13 +49,27 @@ class TorchtunePostTrainingImpl:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
self._scheduler = Scheduler()
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs = {}
self.checkpoints_dict = {}
async def shutdown(self) -> None:
await self._scheduler.shutdown()
async def shutdown(self):
pass
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.RESOURCES_STATS.value,
name=TrainingArtifactType.RESOURCES_STATS.value,
metadata=resources_stats,
)
async def supervised_fine_tune(
self,
@ -56,20 +81,11 @@ class TorchtunePostTrainingImpl:
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig],
) -> PostTrainingJob:
if job_uuid in self.jobs:
raise ValueError(f"Job {job_uuid} already exists")
post_training_job = PostTrainingJob(job_uuid=job_uuid)
job_status_response = PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=JobStatus.scheduled,
scheduled_at=datetime.now(timezone.utc),
)
self.jobs[job_uuid] = job_status_response
if isinstance(algorithm_config, LoraFinetuningConfig):
try:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting Lora finetuning")
recipe = LoraFinetuningSingleDevice(
self.config,
job_uuid,
@ -82,26 +98,22 @@ class TorchtunePostTrainingImpl:
self.datasetio_api,
self.datasets_api,
)
job_status_response.status = JobStatus.in_progress
job_status_response.started_at = datetime.now(timezone.utc)
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
self.checkpoints_dict[job_uuid] = checkpoints
job_status_response.resources_allocated = resources_allocated
job_status_response.checkpoints = checkpoints
job_status_response.status = JobStatus.completed
job_status_response.completed_at = datetime.now(timezone.utc)
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
except Exception:
job_status_response.status = JobStatus.failed
raise
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("Lora finetuning completed")
else:
raise NotImplementedError()
return post_training_job
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
self,
@ -114,19 +126,55 @@ class TorchtunePostTrainingImpl:
) -> PostTrainingJob: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
return self.jobs.get(job_uuid, None)
job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
if job_uuid in self.checkpoints_dict:
checkpoints = self.checkpoints_dict.get(job_uuid, [])
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
return None
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))

View file

@ -36,10 +36,10 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionUnsupportedMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_strategy_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -56,8 +56,8 @@ from .models import MODEL_ENTRIES
class BedrockInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)

View file

@ -34,8 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -54,8 +54,8 @@ from .models import MODEL_ENTRIES
class CerebrasInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: CerebrasImplConfig) -> None:
ModelRegistryHelper.__init__(

View file

@ -34,8 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -61,8 +61,8 @@ model_entries = [
class DatabricksInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=model_entries)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from fireworks.client import Fireworks
from openai import AsyncOpenAI
@ -32,13 +32,20 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
convert_message_to_openai_dict,
get_sampling_options,
prepare_openai_completion_params,
@ -301,6 +308,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
# Fireworks always prepends with BOS
if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
prompt = prompt[len("<|begin_of_text|>") :]
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
@ -320,6 +332,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
top_p=top_p,
user=user,
)
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
@ -336,7 +349,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -347,10 +360,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
@ -374,4 +386,12 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
top_p=top_p,
user=user,
)
return await self._get_openai_client().chat.completions.create(**params)
# Divert Llama Models through Llama Stack inference APIs because
# Fireworks chat completions OpenAI-compatible API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params)
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)

View file

@ -4,8 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from openai import AsyncOpenAI
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChoiceDelta,
OpenAIChunkChoice,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAISystemMessageParam,
)
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import (
prepare_openai_completion_params,
)
from .models import MODEL_ENTRIES
@ -21,9 +37,129 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
provider_data_api_key_field="groq_api_key",
)
self.config = config
self._openai_client = None
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()
if self._openai_client:
await self._openai_client.close()
self._openai_client = None
def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client:
self._openai_client = AsyncOpenAI(
base_url=f"{self.config.url}/openai/v1",
api_key=self.config.api_key,
)
return self._openai_client
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
# Groq does not support json_schema response format, so we need to convert it to json_object
if response_format and response_format.type == "json_schema":
response_format.type = "json_object"
schema = response_format.json_schema.get("schema", {})
response_format.json_schema = None
json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
if messages and messages[0].role == "system":
messages[0].content = messages[0].content + json_instructions
else:
messages.insert(0, OpenAISystemMessageParam(content=json_instructions))
# Groq returns a 400 error if tools are provided but none are called
# So, set tool_choice to "required" to attempt to force a call
if tools and (not tool_choice or tool_choice == "auto"):
tool_choice = "required"
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id.replace("groq/", ""),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
# Groq does not support streaming requests that set response_format
fake_stream = False
if stream and response_format:
params["stream"] = False
fake_stream = True
response = await self._get_openai_client().chat.completions.create(**params)
if fake_stream:
chunk_choices = []
for choice in response.choices:
delta = OpenAIChoiceDelta(
content=choice.message.content,
role=choice.message.role,
tool_calls=choice.message.tool_calls,
)
chunk_choice = OpenAIChunkChoice(
delta=delta,
finish_reason=choice.finish_reason,
index=choice.index,
logprobs=None,
)
chunk_choices.append(chunk_choice)
chunk = OpenAIChatCompletionChunk(
id=response.id,
choices=chunk_choices,
object="chat.completion.chunk",
created=response.created,
model=response.model,
)
async def _fake_stream_generator():
yield chunk
return _fake_stream_generator()
else:
return response

View file

@ -39,8 +39,16 @@ MODEL_ENTRIES = [
"groq/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"groq/meta-llama/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"groq/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
build_hf_repo_model_entry(
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]

View file

@ -34,15 +34,18 @@ from llama_stack.apis.inference import (
ToolChoice,
ToolConfig,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -335,7 +338,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -346,7 +349,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
provider_model_id = self.get_provider_model_id(model)
params = await prepare_openai_completion_params(

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
import httpx
from ollama import AsyncClient
@ -39,10 +39,20 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -87,8 +97,19 @@ class OllamaInferenceAdapter(
async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
await self.health()
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the Ollama server.
This method is used by initialize() and the Provider API to verify that the service is running
correctly.
Returns:
HealthResponse: A dictionary containing the health status.
"""
try:
await self.client.ps()
return HealthResponse(status=HealthStatus.OK)
except httpx.ConnectError as e:
raise RuntimeError(
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
@ -322,6 +343,12 @@ class OllamaInferenceAdapter(
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id not in available_models:
available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
if model.provider_resource_id in available_models_latest:
logger.warning(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
)
return model
raise ValueError(
f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
)
@ -393,7 +420,7 @@ class OllamaInferenceAdapter(
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -404,7 +431,7 @@ class OllamaInferenceAdapter(
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self._get_model(model)
params = {
k: v

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack_client import AsyncLlamaStackClient
@ -26,7 +26,13 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@ -266,7 +272,7 @@ class PassthroughInferenceAdapter(Inference):
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -277,7 +283,7 @@ class PassthroughInferenceAdapter(Inference):
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
client = self._get_client()
model_obj = await self.model_store.get_model(model)

View file

@ -12,8 +12,8 @@ from llama_stack.apis.inference import * # noqa: F403
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -43,8 +43,8 @@ RUNPOD_SUPPORTED_MODELS = {
class RunpodInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)

View file

@ -42,8 +42,8 @@ from llama_stack.apis.inference import (
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -57,8 +57,8 @@ from .models import MODEL_ENTRIES
class SambaNovaInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: SambaNovaImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)

View file

@ -40,10 +40,10 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionUnsupportedMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -73,8 +73,8 @@ def build_hf_repo_model_entries():
class _HfAdapter(
Inference,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
client: AsyncInferenceClient

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from openai import AsyncOpenAI
from together import AsyncTogether
@ -31,7 +31,13 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@ -315,7 +321,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -326,7 +332,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
@ -353,4 +359,26 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
top_p=top_p,
user=user,
)
if params.get("stream", True):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
# together.ai sometimes adds usage data to the stream, even if include_usage is False
# This causes an unexpected final chunk with empty choices array to be sent
# to clients that may not handle it gracefully.
include_usage = False
if params.get("stream_options", None):
include_usage = params["stream_options"].get("include_usage", False)
stream = await self._get_openai_client().chat.completions.create(**params)
seen_finish_reason = False
async for chunk in stream:
# Final usage chunk with no choices that the user didn't request, so discard
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
break
yield chunk
for choice in chunk.choices:
if choice.finish_reason:
seen_finish_reason = True
break

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
import logging
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
import httpx
from openai import AsyncOpenAI
@ -45,7 +45,12 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
@ -369,7 +374,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
options["max_tokens"] = self.config.max_tokens
input_dict: dict[str, Any] = {}
if isinstance(request, ChatCompletionRequest) and request.tools is not None:
# Only include the 'tools' param if there is any. It can break things if an empty list is sent to the vLLM.
if isinstance(request, ChatCompletionRequest) and request.tools:
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
if isinstance(request, ChatCompletionRequest):
@ -487,7 +493,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -498,7 +504,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,

View file

@ -30,7 +30,13 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
@ -270,7 +276,7 @@ class LiteLLMOpenAIMixin(
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self._get_model(model)
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
@ -292,7 +298,7 @@ class LiteLLMOpenAIMixin(
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
return litellm.text_completion(**params)
return await litellm.atext_completion(**params)
async def openai_chat_completion(
self,
@ -308,7 +314,7 @@ class LiteLLMOpenAIMixin(
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -319,8 +325,8 @@ class LiteLLMOpenAIMixin(
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
model_obj = await self._get_model(model)
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
@ -346,7 +352,7 @@ class LiteLLMOpenAIMixin(
top_p=top_p,
user=user,
)
return litellm.completion(**params)
return await litellm.acompletion(**params)
async def batch_completion(
self,

View file

@ -8,7 +8,7 @@ import logging
import time
import uuid
import warnings
from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Union
from openai import AsyncStream
from openai.types.chat import (
@ -50,6 +50,18 @@ from openai.types.chat.chat_completion import (
from openai.types.chat.chat_completion import (
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_chunk import (
Choice as OpenAIChatCompletionChunkChoice,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
)
from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL,
)
@ -59,6 +71,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
URL,
ImageContentItem,
InterleavedContent,
TextContentItem,
@ -85,12 +98,24 @@ from llama_stack.apis.inference import (
TopPSamplingStrategy,
UserMessage,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice
from llama_stack.apis.inference.inference import (
JsonSchemaResponseFormat,
OpenAIChatCompletion,
OpenAICompletion,
OpenAICompletionChoice,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ToolConfig,
)
from llama_stack.apis.inference.inference import (
OpenAIChoice as OpenAIChatCompletionChoice,
)
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
@ -751,6 +776,17 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
return out
def _convert_stop_reason_to_openai_finish_reason(stop_reason: StopReason) -> str:
"""
Convert a StopReason to an OpenAI chat completion finish_reason.
"""
return {
StopReason.end_of_turn: "stop",
StopReason.end_of_message: "tool_calls",
StopReason.out_of_tokens: "length",
}.get(stop_reason, "stop")
def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
"""
Convert an OpenAI chat completion finish_reason to a StopReason.
@ -776,6 +812,56 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
}.get(finish_reason, StopReason.end_of_turn)
def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
tool_config = ToolConfig()
if tool_choice:
tool_config.tool_choice = tool_choice
return tool_config
def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None) -> List[ToolDefinition]:
lls_tools = []
if not tools:
return lls_tools
for tool in tools:
tool_fn = tool.get("function", {})
tool_name = tool_fn.get("name", None)
tool_desc = tool_fn.get("description", None)
tool_params = tool_fn.get("parameters", None)
lls_tool_params = {}
if tool_params is not None:
tool_param_properties = tool_params.get("properties", {})
for tool_param_key, tool_param_value in tool_param_properties.items():
tool_param_def = ToolParamDefinition(
param_type=tool_param_value.get("type", None),
description=tool_param_value.get("description", None),
)
lls_tool_params[tool_param_key] = tool_param_def
lls_tool = ToolDefinition(
tool_name=tool_name,
description=tool_desc,
parameters=lls_tool_params,
)
lls_tools.append(lls_tool)
return lls_tools
def _convert_openai_request_response_format(response_format: OpenAIResponseFormatParam = None):
if not response_format:
return None
# response_format can be a dict or a pydantic model
response_format = dict(response_format)
if response_format.get("type", "") == "json_schema":
return JsonSchemaResponseFormat(
type="json_schema",
json_schema=response_format.get("json_schema", {}).get("schema", ""),
)
return None
def _convert_openai_tool_calls(
tool_calls: List[OpenAIChatCompletionMessageToolCall],
) -> List[ToolCall]:
@ -871,6 +957,40 @@ def _convert_openai_sampling_params(
return sampling_params
def _convert_openai_request_messages(messages: List[OpenAIMessageParam]):
# Llama Stack messages and OpenAI messages are similar, but not identical.
lls_messages = []
for message in messages:
lls_message = dict(message)
# Llama Stack expects `call_id` but OpenAI uses `tool_call_id`
tool_call_id = lls_message.pop("tool_call_id", None)
if tool_call_id:
lls_message["call_id"] = tool_call_id
content = lls_message.get("content", None)
if isinstance(content, list):
lls_content = []
for item in content:
# items can either by pydantic models or dicts here...
item = dict(item)
if item.get("type", "") == "image_url":
lls_item = ImageContentItem(
type="image",
image=URL(uri=item.get("image_url", {}).get("url", "")),
)
elif item.get("type", "") == "text":
lls_item = TextContentItem(
type="text",
text=item.get("text", ""),
)
lls_content.append(lls_item)
lls_message["content"] = lls_content
lls_messages.append(lls_message)
return lls_messages
def convert_openai_chat_completion_choice(
choice: OpenAIChoice,
) -> ChatCompletionResponse:
@ -1080,11 +1200,24 @@ async def convert_openai_chat_completion_stream(
async def prepare_openai_completion_params(**params):
completion_params = {k: v for k, v in params.items() if v is not None}
async def _prepare_value(value: Any) -> Any:
new_value = value
if isinstance(value, list):
new_value = [await _prepare_value(v) for v in value]
elif isinstance(value, dict):
new_value = {k: await _prepare_value(v) for k, v in value.items()}
elif isinstance(value, BaseModel):
new_value = value.model_dump(exclude_none=True)
return new_value
completion_params = {}
for k, v in params.items():
if v is not None:
completion_params[k] = await _prepare_value(v)
return completion_params
class OpenAICompletionUnsupportedMixin:
class OpenAICompletionToLlamaStackMixin:
async def openai_completion(
self,
model: str,
@ -1122,6 +1255,7 @@ class OpenAICompletionUnsupportedMixin:
choices = []
# "n" is the number of completions to generate per prompt
n = n or 1
for _i in range(0, n):
# and we may have multiple prompts, if batching was used
@ -1134,7 +1268,7 @@ class OpenAICompletionUnsupportedMixin:
index = len(choices)
text = result.content
finish_reason = _convert_openai_finish_reason(result.stop_reason)
finish_reason = _convert_stop_reason_to_openai_finish_reason(result.stop_reason)
choice = OpenAICompletionChoice(
index=index,
@ -1152,7 +1286,7 @@ class OpenAICompletionUnsupportedMixin:
)
class OpenAIChatCompletionUnsupportedMixin:
class OpenAIChatCompletionToLlamaStackMixin:
async def openai_chat_completion(
self,
model: str,
@ -1167,7 +1301,7 @@ class OpenAIChatCompletionUnsupportedMixin:
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -1178,5 +1312,103 @@ class OpenAIChatCompletionUnsupportedMixin:
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages = _convert_openai_request_messages(messages)
response_format = _convert_openai_request_response_format(response_format)
sampling_params = _convert_openai_sampling_params(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
tool_config = _convert_openai_request_tool_config(tool_choice)
tools = _convert_openai_request_tools(tools)
outstanding_responses = []
# "n" is the number of completions to generate per prompt
n = n or 1
for _i in range(0, n):
response = self.chat_completion(
model_id=model,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
tool_config=tool_config,
tools=tools,
)
outstanding_responses.append(response)
if stream:
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
self, model, outstanding_responses
)
async def _process_stream_response(
self, model: str, outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]]
):
id = f"chatcmpl-{uuid.uuid4()}"
for outstanding_response in outstanding_responses:
response = await outstanding_response
i = 0
async for chunk in response:
event = chunk.event
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
if isinstance(event.delta, TextDelta):
text_delta = event.delta.text
delta = OpenAIChoiceDelta(content=text_delta)
yield OpenAIChatCompletionChunk(
id=id,
choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
created=int(time.time()),
model=model,
object="chat.completion.chunk",
)
elif isinstance(event.delta, ToolCallDelta):
if event.delta.parse_status == ToolCallParseStatus.succeeded:
tool_call = event.delta.tool_call
openai_tool_call = OpenAIChoiceDeltaToolCall(
index=0,
id=tool_call.call_id,
function=OpenAIChoiceDeltaToolCallFunction(
name=tool_call.tool_name, arguments=tool_call.arguments_json
),
)
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
yield OpenAIChatCompletionChunk(
id=id,
choices=[
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
],
created=int(time.time()),
model=model,
object="chat.completion.chunk",
)
i = i + 1
async def _process_non_stream_response(
self, model: str, outstanding_responses: List[Awaitable[ChatCompletionResponse]]
) -> OpenAIChatCompletion:
raise ValueError(f"{self.__class__.__name__} doesn't support openai chat completion")
choices = []
for outstanding_response in outstanding_responses:
response = await outstanding_response
completion_message = response.completion_message
message = await convert_message_to_openai_dict_new(completion_message)
finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)
choice = OpenAIChatCompletionChoice(
index=len(choices),
message=message,
finish_reason=finish_reason,
)
choices.append(choice)
return OpenAIChatCompletion(
id=f"chatcmpl-{uuid.uuid4()}",
choices=choices,
created=int(time.time()),
model=model,
object="chat.completion",
)

View file

@ -0,0 +1,265 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import abc
import asyncio
import functools
import threading
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Callable, Coroutine, Dict, Iterable, Tuple, TypeAlias
from pydantic import BaseModel
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="scheduler")
# TODO: revisit the list of possible statuses when defining a more coherent
# Jobs API for all API flows; e.g. do we need new vs scheduled?
class JobStatus(Enum):
new = "new"
scheduled = "scheduled"
running = "running"
failed = "failed"
completed = "completed"
JobID: TypeAlias = str
JobType: TypeAlias = str
class JobArtifact(BaseModel):
type: JobType
name: str
# TODO: uri should be a reference to /files API; revisit when /files is implemented
uri: str | None = None
metadata: Dict[str, Any]
JobHandler = Callable[
[Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None]
]
LogMessage: TypeAlias = Tuple[datetime, str]
_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}
class Job:
def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
super().__init__()
self.id = job_id
self._type = job_type
self._handler = handler
self._artifacts: list[JobArtifact] = []
self._logs: list[LogMessage] = []
self._state_transitions: list[Tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
@property
def handler(self) -> JobHandler:
return self._handler
@property
def status(self) -> JobStatus:
return self._state_transitions[-1][1]
@status.setter
def status(self, status: JobStatus):
if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
raise ValueError(f"Job is already in a completed state ({self.status})")
if self.status == status:
return
self._state_transitions.append((datetime.now(timezone.utc), status))
@property
def artifacts(self) -> list[JobArtifact]:
return self._artifacts
def register_artifact(self, artifact: JobArtifact) -> None:
self._artifacts.append(artifact)
def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
for date, s in reversed(self._state_transitions):
if s in status:
return date
return None
@property
def scheduled_at(self) -> datetime | None:
return self._find_state_transition_date([JobStatus.scheduled])
@property
def started_at(self) -> datetime | None:
return self._find_state_transition_date([JobStatus.running])
@property
def completed_at(self) -> datetime | None:
return self._find_state_transition_date(_COMPLETED_STATUSES)
@property
def logs(self) -> list[LogMessage]:
return self._logs[:]
def append_log(self, message: LogMessage) -> None:
self._logs.append(message)
# TODO: implement
def cancel(self) -> None:
raise NotImplementedError
class _SchedulerBackend(abc.ABC):
@abc.abstractmethod
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
raise NotImplementedError
@abc.abstractmethod
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
raise NotImplementedError
@abc.abstractmethod
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
raise NotImplementedError
@abc.abstractmethod
async def shutdown(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
raise NotImplementedError
class _NaiveSchedulerBackend(_SchedulerBackend):
def __init__(self, timeout: int = 5):
self._timeout = timeout
self._loop = asyncio.new_event_loop()
# There may be performance implications of using threads due to Python
# GIL; may need to measure if it's a real problem though
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
def _run_loop(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
# When stopping the loop, give tasks a chance to finish
# TODO: should we explicitly inform jobs of pending stoppage?
for task in asyncio.all_tasks(self._loop):
self._loop.run_until_complete(task)
self._loop.close()
async def shutdown(self) -> None:
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
# TODO: decouple scheduling and running the job
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
async def do():
try:
job.status = JobStatus.running
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
logger.exception(f"Job {job.id} failed.")
asyncio.run_coroutine_threadsafe(do(), self._loop)
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
pass
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
pass
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
pass
_BACKENDS = {
"naive": _NaiveSchedulerBackend,
}
def _get_backend_impl(backend: str) -> _SchedulerBackend:
try:
return _BACKENDS[backend]()
except KeyError as e:
raise ValueError(f"Unknown backend {backend}") from e
class Scheduler:
def __init__(self, backend: str = "naive"):
# TODO: if server crashes, job states are lost; we need to persist jobs on disc
self._jobs: dict[JobID, Job] = {}
self._backend = _get_backend_impl(backend)
def _on_log_message_cb(self, job: Job, message: str) -> None:
msg = (datetime.now(timezone.utc), message)
# At least for the time being, until there's a better way to expose
# logs to users, log messages on console
logger.info(f"Job {job.id}: {message}")
job.append_log(msg)
self._backend.on_log_message_cb(job, msg)
def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
job.status = status
self._backend.on_status_change_cb(job, status)
def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
job.register_artifact(artifact)
self._backend.on_artifact_collected_cb(job, artifact)
def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
job = Job(type_, job_id, handler)
if job.id in self._jobs:
raise ValueError(f"Job {job.id} already exists")
self._jobs[job.id] = job
job.status = JobStatus.scheduled
self._backend.schedule(
job,
functools.partial(self._on_log_message_cb, job),
functools.partial(self._on_status_change_cb, job),
functools.partial(self._on_artifact_collected_cb, job),
)
return job.id
def cancel(self, job_id: JobID) -> None:
self.get_job(job_id).cancel()
def get_job(self, job_id: JobID) -> Job:
try:
return self._jobs[job_id]
except KeyError as e:
raise ValueError(f"Job {job_id} not found") from e
def get_jobs(self, type_: JobType | None = None) -> list[Job]:
jobs = list(self._jobs.values())
if type_:
jobs = [job for job in jobs if job._type == type_]
return jobs
async def shutdown(self):
# TODO: also cancel jobs once implemented
await self._backend.shutdown()

View file

@ -386,6 +386,16 @@ models:
provider_id: groq
provider_model_id: groq/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: groq/llama-4-maverick-17b-128e-instruct
provider_id: groq
@ -396,6 +406,16 @@ models:
provider_id: groq
provider_model_id: groq/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: {}
model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2

View file

@ -158,6 +158,16 @@ models:
provider_id: groq
provider_model_id: groq/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: groq/llama-4-maverick-17b-128e-instruct
provider_id: groq
@ -168,6 +178,16 @@ models:
provider_id: groq
provider_model_id: groq/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: {}
model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2

View file

@ -28,7 +28,7 @@ The following environment variables can be configured:
## Setting up vLLM server
In the following sections, we'll use either AMD and NVIDIA GPUs to serve as hardware accelerators for the vLLM
In the following sections, we'll use AMD, NVIDIA or Intel GPUs to serve as hardware accelerators for the vLLM
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
that we only use GPUs here for demonstration purposes.
@ -149,6 +149,55 @@ docker run \
--port $SAFETY_PORT
```
### Setting up vLLM server on Intel GPU
Refer to [vLLM Documentation for XPU](https://docs.vllm.ai/en/v0.8.2/getting_started/installation/gpu.html?device=xpu) to get a vLLM endpoint. In addition to vLLM side setup which guides towards installing vLLM from sources orself-building vLLM Docker container, Intel provides prebuilt vLLM container to use on systems with Intel GPUs supported by PyTorch XPU backend:
- [intel/vllm](https://hub.docker.com/r/intel/vllm)
Here is a sample script to start a vLLM server locally via Docker using Intel provided container:
```bash
export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct
export ZE_AFFINITY_MASK=0
docker run \
--pull always \
--device /dev/dri \
-v /dev/dri/by-path:/dev/dri/by-path \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
--env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
-p $INFERENCE_PORT:$INFERENCE_PORT \
--ipc=host \
intel/vllm:xpu \
--gpu-memory-utilization 0.7 \
--model $INFERENCE_MODEL \
--port $INFERENCE_PORT
```
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
```bash
export SAFETY_PORT=8081
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export ZE_AFFINITY_MASK=1
docker run \
--pull always \
--device /dev/dri \
-v /dev/dri/by-path:/dev/dri/by-path \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
--env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
-p $SAFETY_PORT:$SAFETY_PORT \
--ipc=host \
intel/vllm:xpu \
--gpu-memory-utilization 0.7 \
--model $SAFETY_MODEL \
--port $SAFETY_PORT
```
## Running Llama Stack
Now you are ready to run Llama Stack with vLLM as the inference provider. You can do this via Conda (build code) or Docker which has a pre-built image.

View file

@ -474,6 +474,16 @@ models:
provider_id: groq-openai-compat
provider_model_id: groq/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
provider_id: groq-openai-compat
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
provider_id: groq-openai-compat
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: groq/llama-4-maverick-17b-128e-instruct
provider_id: groq-openai-compat
@ -484,6 +494,16 @@ models:
provider_id: groq-openai-compat
provider_model_id: groq/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: {}
model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
provider_id: groq-openai-compat
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
provider_id: groq-openai-compat
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: {}
model_id: Meta-Llama-3.1-8B-Instruct
provider_id: sambanova-openai-compat

View file

@ -115,7 +115,7 @@ def test_openai_completion_streaming(openai_client, client_with_models, text_mod
stream=True,
max_tokens=50,
)
streamed_content = [chunk.choices[0].text for chunk in response]
streamed_content = [chunk.choices[0].text or "" for chunk in response]
content_str = "".join(streamed_content).lower().strip()
assert len(content_str) > 10

View file

@ -26,7 +26,12 @@ from openai.types.chat.chat_completion_chunk import (
)
from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import ToolChoice, ToolConfig
from llama_stack.apis.inference import (
ChatCompletionRequest,
ToolChoice,
ToolConfig,
UserMessage,
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
@ -232,3 +237,14 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
# above.
asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
assert not asyncio_warnings
@pytest.mark.asyncio
async def test_get_params_empty_tools(vllm_inference_adapter):
request = ChatCompletionRequest(
tools=[],
model="test_model",
messages=[UserMessage(content="test")],
)
params = await vllm_inference_adapter._get_params(request)
assert "tools" not in params

View file

@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import pytest
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
@pytest.mark.asyncio
async def test_scheduler_unknown_backend():
with pytest.raises(ValueError):
Scheduler(backend="unknown")
@pytest.mark.asyncio
async def test_scheduler_naive():
sched = Scheduler()
# make sure the scheduler starts empty
with pytest.raises(ValueError):
sched.get_job("unknown")
assert sched.get_jobs() == []
called = False
# schedule a job that will exercise the handlers
async def job_handler(on_log, on_status, on_artifact):
nonlocal called
called = True
# exercise the handlers
on_log("test log1")
on_log("test log2")
on_artifact({"type": "type1", "path": "path1"})
on_artifact({"type": "type2", "path": "path2"})
on_status(JobStatus.completed)
job_id = "test_job_id"
job_type = "test_job_type"
sched.schedule(job_type, job_id, job_handler)
# make sure the job was properly registered
with pytest.raises(ValueError):
sched.get_job("unknown")
assert sched.get_job(job_id) is not None
assert sched.get_jobs() == [sched.get_job(job_id)]
assert sched.get_jobs("unknown") == []
assert sched.get_jobs(job_type) == [sched.get_job(job_id)]
# now shut the scheduler down and make sure the job ran
await sched.shutdown()
assert called
job = sched.get_job(job_id)
assert job is not None
assert job.status == JobStatus.completed
assert job.scheduled_at is not None
assert job.started_at is not None
assert job.completed_at is not None
assert job.scheduled_at < job.started_at < job.completed_at
assert job.artifacts == [
{"type": "type1", "path": "path1"},
{"type": "type2", "path": "path2"},
]
assert [msg[1] for msg in job.logs] == ["test log1", "test log2"]
assert job.logs[0][0] < job.logs[1][0]
@pytest.mark.asyncio
async def test_scheduler_naive_handler_raises():
sched = Scheduler()
async def failing_job_handler(on_log, on_status, on_artifact):
on_status(JobStatus.running)
raise ValueError("test error")
job_id = "test_job_id1"
job_type = "test_job_type"
sched.schedule(job_type, job_id, failing_job_handler)
job = sched.get_job(job_id)
assert job is not None
# confirm the exception made the job transition to failed state, even
# though it was set to `running` before the error
for _ in range(10):
if job.status == JobStatus.failed:
break
await asyncio.sleep(0.1)
assert job.status == JobStatus.failed
# confirm that the raised error got registered in log
assert job.logs[0][1] == "test error"
# even after failed job, we can schedule another one
called = False
async def successful_job_handler(on_log, on_status, on_artifact):
nonlocal called
called = True
on_status(JobStatus.completed)
job_id = "test_job_id2"
sched.schedule(job_type, job_id, successful_job_handler)
await sched.shutdown()
assert called
job = sched.get_job(job_id)
assert job is not None
assert job.status == JobStatus.completed

View file

@ -1,6 +1,6 @@
# Test Results Report
*Generated on: 2025-04-10 16:48:18*
*Generated on: 2025-04-14 18:11:37*
*This report was generated by running `python tests/verifications/generate_report.py`*
@ -15,15 +15,15 @@
| Provider | Pass Rate | Tests Passed | Total Tests |
| --- | --- | --- | --- |
| Together | 64.7% | 22 | 34 |
| Fireworks | 82.4% | 28 | 34 |
| Openai | 100.0% | 24 | 24 |
| Together | 48.7% | 37 | 76 |
| Fireworks | 47.4% | 36 | 76 |
| Openai | 100.0% | 52 | 52 |
## Together
*Tests run on: 2025-04-10 16:46:35*
*Tests run on: 2025-04-14 18:08:14*
```bash
# Run all tests for this provider:
@ -48,19 +48,33 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=togethe
| test_chat_non_streaming_basic (earth) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_image | ⚪ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ | ❌ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_structured_output (calendar) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_structured_output (math) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_tool_calling | ✅ | ✅ | ✅ |
| test_chat_non_streaming_tool_choice_none | ❌ | ❌ | ❌ |
| test_chat_non_streaming_tool_choice_required | ✅ | ✅ | ✅ |
| test_chat_streaming_basic (earth) | ✅ | ❌ | ❌ |
| test_chat_streaming_basic (saturn) | ✅ | ❌ | ❌ |
| test_chat_streaming_image | ⚪ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ❌ | ❌ | ❌ |
| test_chat_streaming_structured_output (calendar) | ✅ | ❌ | ❌ |
| test_chat_streaming_structured_output (math) | ✅ | ❌ | ❌ |
| test_chat_streaming_tool_calling | ✅ | ❌ | ❌ |
| test_chat_streaming_tool_choice_none | ❌ | ❌ | ❌ |
| test_chat_streaming_tool_choice_required | ✅ | ❌ | ❌ |
## Fireworks
*Tests run on: 2025-04-10 16:44:44*
*Tests run on: 2025-04-14 18:04:06*
```bash
# Run all tests for this provider:
@ -85,19 +99,33 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=firewor
| test_chat_non_streaming_basic (earth) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_image | ⚪ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_structured_output (calendar) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_structured_output (math) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_tool_calling | ❌ | ❌ | ❌ |
| test_chat_non_streaming_tool_choice_none | ✅ | ✅ | ✅ |
| test_chat_non_streaming_tool_choice_required | ✅ | ❌ | ❌ |
| test_chat_streaming_basic (earth) | ✅ | ✅ | ✅ |
| test_chat_streaming_basic (saturn) | ✅ | ✅ | ✅ |
| test_chat_streaming_image | ⚪ | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ❌ | ❌ | ❌ |
| test_chat_streaming_structured_output (calendar) | ✅ | ✅ | ✅ |
| test_chat_streaming_structured_output (math) | ✅ | ✅ | ✅ |
| test_chat_streaming_tool_calling | ❌ | ❌ | ❌ |
| test_chat_streaming_tool_choice_none | ✅ | ✅ | ✅ |
| test_chat_streaming_tool_choice_required | ✅ | ❌ | ❌ |
## Openai
*Tests run on: 2025-04-10 16:47:28*
*Tests run on: 2025-04-14 18:09:51*
```bash
# Run all tests for this provider:
@ -121,12 +149,26 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai
| test_chat_non_streaming_basic (earth) | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ |
| test_chat_non_streaming_image | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ | ✅ |
| test_chat_non_streaming_structured_output (calendar) | ✅ | ✅ |
| test_chat_non_streaming_structured_output (math) | ✅ | ✅ |
| test_chat_non_streaming_tool_calling | ✅ | ✅ |
| test_chat_non_streaming_tool_choice_none | ✅ | ✅ |
| test_chat_non_streaming_tool_choice_required | ✅ | ✅ |
| test_chat_streaming_basic (earth) | ✅ | ✅ |
| test_chat_streaming_basic (saturn) | ✅ | ✅ |
| test_chat_streaming_image | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ | ✅ |
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ | ✅ |
| test_chat_streaming_structured_output (calendar) | ✅ | ✅ |
| test_chat_streaming_structured_output (math) | ✅ | ✅ |
| test_chat_streaming_tool_calling | ✅ | ✅ |
| test_chat_streaming_tool_choice_none | ✅ | ✅ |
| test_chat_streaming_tool_choice_required | ✅ | ✅ |

View file

@ -0,0 +1,14 @@
base_url: http://localhost:8321/v1/openai/v1
api_key_var: FIREWORKS_API_KEY
models:
- fireworks/llama-v3p3-70b-instruct
- fireworks/llama4-scout-instruct-basic
- fireworks/llama4-maverick-instruct-basic
model_display_names:
fireworks/llama-v3p3-70b-instruct: Llama-3.3-70B-Instruct
fireworks/llama4-scout-instruct-basic: Llama-4-Scout-Instruct
fireworks/llama4-maverick-instruct-basic: Llama-4-Maverick-Instruct
test_exclusions:
fireworks/llama-v3p3-70b-instruct:
- test_chat_non_streaming_image
- test_chat_streaming_image

View file

@ -0,0 +1,14 @@
base_url: http://localhost:8321/v1/openai/v1
api_key_var: GROQ_API_KEY
models:
- groq/llama-3.3-70b-versatile
- groq/llama-4-scout-17b-16e-instruct
- groq/llama-4-maverick-17b-128e-instruct
model_display_names:
groq/llama-3.3-70b-versatile: Llama-3.3-70B-Instruct
groq/llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
groq/llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
test_exclusions:
groq/llama-3.3-70b-versatile:
- test_chat_non_streaming_image
- test_chat_streaming_image

View file

@ -2,12 +2,12 @@ base_url: https://api.groq.com/openai/v1
api_key_var: GROQ_API_KEY
models:
- llama-3.3-70b-versatile
- llama-4-scout-17b-16e-instruct
- llama-4-maverick-17b-128e-instruct
- meta-llama/llama-4-scout-17b-16e-instruct
- meta-llama/llama-4-maverick-17b-128e-instruct
model_display_names:
llama-3.3-70b-versatile: Llama-3.3-70B-Instruct
llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
meta-llama/llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
meta-llama/llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
test_exclusions:
llama-3.3-70b-versatile:
- test_chat_non_streaming_image

View file

@ -0,0 +1,9 @@
base_url: http://localhost:8321/v1/openai/v1
api_key_var: OPENAI_API_KEY
models:
- openai/gpt-4o
- openai/gpt-4o-mini
model_display_names:
openai/gpt-4o: gpt-4o
openai/gpt-4o-mini: gpt-4o-mini
test_exclusions: {}

View file

@ -0,0 +1,14 @@
base_url: http://localhost:8321/v1/openai/v1
api_key_var: TOGETHER_API_KEY
models:
- together/meta-llama/Llama-3.3-70B-Instruct-Turbo
- together/meta-llama/Llama-4-Scout-17B-16E-Instruct
- together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
model_display_names:
together/meta-llama/Llama-3.3-70B-Instruct-Turbo: Llama-3.3-70B-Instruct
together/meta-llama/Llama-4-Scout-17B-16E-Instruct: Llama-4-Scout-Instruct
together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8: Llama-4-Maverick-Instruct
test_exclusions:
together/meta-llama/Llama-3.3-70B-Instruct-Turbo:
- test_chat_non_streaming_image
- test_chat_streaming_image

View file

@ -67,7 +67,17 @@ RESULTS_DIR.mkdir(exist_ok=True)
# Maximum number of test result files to keep per provider
MAX_RESULTS_PER_PROVIDER = 1
PROVIDER_ORDER = ["together", "fireworks", "groq", "cerebras", "openai"]
PROVIDER_ORDER = [
"together",
"fireworks",
"groq",
"cerebras",
"openai",
"together-llama-stack",
"fireworks-llama-stack",
"groq-llama-stack",
"openai-llama-stack",
]
VERIFICATION_CONFIG = _load_all_verification_configs()

View file

@ -0,0 +1,146 @@
version: '2'
image_name: openai-api-verification
apis:
- inference
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY}
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY}
- provider_id: openai
provider_type: remote::openai
config:
url: https://api.openai.com/v1
api_key: ${env.OPENAI_API_KEY:}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/faiss_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/openai/trace_store.db}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
- provider_id: wolfram-alpha
provider_type: remote::wolfram-alpha
config:
api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/registry.db
models:
- metadata: {}
model_id: together/meta-llama/Llama-3.3-70B-Instruct-Turbo
provider_id: together
provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo
model_type: llm
- metadata: {}
model_id: together/meta-llama/Llama-4-Scout-17B-16E-Instruct
provider_id: together
provider_model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
model_type: llm
- metadata: {}
model_id: together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
provider_id: together
provider_model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
model_type: llm
- metadata: {}
model_id: fireworks/llama-v3p3-70b-instruct
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
model_type: llm
- metadata: {}
model_id: fireworks/llama4-scout-instruct-basic
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic
model_type: llm
- metadata: {}
model_id: fireworks/llama4-maverick-instruct-basic
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
model_type: llm
- metadata: {}
model_id: groq/llama-3.3-70b-versatile
provider_id: groq
provider_model_id: groq/llama-3.3-70b-versatile
model_type: llm
- metadata: {}
model_id: groq/llama-4-scout-17b-16e-instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: groq/llama-4-maverick-17b-128e-instruct
provider_id: groq
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
model_type: llm
- metadata: {}
model_id: openai/gpt-4o
provider_id: openai
provider_model_id: openai/gpt-4o
model_type: llm
- metadata: {}
model_id: openai/gpt-4o-mini
provider_id: openai
provider_model_id: openai/gpt-4o-mini
model_type: llm
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
- toolgroup_id: builtin::wolfram_alpha
provider_id: wolfram-alpha
server:
port: 8321

View file

@ -99,6 +99,9 @@ def model_mapping(provider, providers_model_mapping):
@pytest.fixture
def openai_client(base_url, api_key):
# Simplify running against a local Llama Stack
if "localhost" in base_url and not api_key:
api_key = "empty"
return OpenAI(
base_url=base_url,
api_key=api_key,

View file

@ -131,3 +131,221 @@ test_tool_calling:
type: object
type: function
output: get_weather_tool_call
test_chat_multi_turn_tool_calling:
test_name: test_chat_multi_turn_tool_calling
test_params:
case:
- case_id: "text_then_weather_tool"
input:
messages:
- - role: user
content: "What's the name of the Sun in latin?"
- - role: user
content: "What's the weather like in San Francisco?"
tools:
- function:
description: Get the current weather
name: get_weather
parameters:
type: object
properties:
location:
description: "The city and state (both required), e.g. San Francisco, CA."
type: string
required: ["location"]
type: function
tool_responses:
- response: "{'response': '70 degrees and foggy'}"
expected:
- num_tool_calls: 0
answer: ["sol"]
- num_tool_calls: 1
tool_name: get_weather
tool_arguments:
location: "San Francisco, CA"
- num_tool_calls: 0
answer: ["foggy", "70 degrees"]
- case_id: "weather_tool_then_text"
input:
messages:
- - role: user
content: "What's the weather like in San Francisco?"
tools:
- function:
description: Get the current weather
name: get_weather
parameters:
type: object
properties:
location:
description: "The city and state (both required), e.g. San Francisco, CA."
type: string
required: ["location"]
type: function
tool_responses:
- response: "{'response': '70 degrees and foggy'}"
expected:
- num_tool_calls: 1
tool_name: get_weather
tool_arguments:
location: "San Francisco, CA"
- num_tool_calls: 0
answer: ["foggy", "70 degrees"]
- case_id: "add_product_tool"
input:
messages:
- - role: user
content: "Please add a new product with name 'Widget', price 19.99, in stock, and tags ['new', 'sale'] and give me the product id."
tools:
- function:
description: Add a new product
name: addProduct
parameters:
type: object
properties:
name:
description: "Name of the product"
type: string
price:
description: "Price of the product"
type: number
inStock:
description: "Availability status of the product."
type: boolean
tags:
description: "List of product tags"
type: array
items:
type: string
required: ["name", "price", "inStock"]
type: function
tool_responses:
- response: "{'response': 'Successfully added product with id: 123'}"
expected:
- num_tool_calls: 1
tool_name: addProduct
tool_arguments:
name: "Widget"
price: 19.99
inStock: true
tags:
- "new"
- "sale"
- num_tool_calls: 0
answer: ["123", "product id: 123"]
- case_id: "get_then_create_event_tool"
input:
messages:
- - role: system
content: "Todays date is 2025-03-01."
- role: user
content: "Do i have any meetings on March 3rd at 10 am? Yes or no?"
- - role: user
content: "Alright then, Create an event named 'Team Building', scheduled for that time same time, in the 'Main Conference Room' and add Alice, Bob, Charlie to it. Give me the created event id."
tools:
- function:
description: Create a new event
name: create_event
parameters:
type: object
properties:
name:
description: "Name of the event"
type: string
date:
description: "Date of the event in ISO format"
type: string
time:
description: "Event Time (HH:MM)"
type: string
location:
description: "Location of the event"
type: string
participants:
description: "List of participant names"
type: array
items:
type: string
required: ["name", "date", "time", "location", "participants"]
type: function
- function:
description: Get an event by date and time
name: get_event
parameters:
type: object
properties:
date:
description: "Date of the event in ISO format"
type: string
time:
description: "Event Time (HH:MM)"
type: string
required: ["date", "time"]
type: function
tool_responses:
- response: "{'response': 'No events found for 2025-03-03 at 10:00'}"
- response: "{'response': 'Successfully created new event with id: e_123'}"
expected:
- num_tool_calls: 1
tool_name: get_event
tool_arguments:
date: "2025-03-03"
time: "10:00"
- num_tool_calls: 0
answer: ["no", "no events found", "no meetings"]
- num_tool_calls: 1
tool_name: create_event
tool_arguments:
name: "Team Building"
date: "2025-03-03"
time: "10:00"
location: "Main Conference Room"
participants:
- "Alice"
- "Bob"
- "Charlie"
- num_tool_calls: 0
answer: ["e_123", "event id: e_123"]
- case_id: "compare_monthly_expense_tool"
input:
messages:
- - role: system
content: "Todays date is 2025-03-01."
- role: user
content: "what was my monthly expense in Jan of this year?"
- - role: user
content: "Was it less than Feb of last year? Only answer with yes or no."
tools:
- function:
description: Get monthly expense summary
name: getMonthlyExpenseSummary
parameters:
type: object
properties:
month:
description: "Month of the year (1-12)"
type: integer
year:
description: "Year"
type: integer
required: ["month", "year"]
type: function
tool_responses:
- response: "{'response': 'Total expenses for January 2025: $1000'}"
- response: "{'response': 'Total expenses for February 2024: $2000'}"
expected:
- num_tool_calls: 1
tool_name: getMonthlyExpenseSummary
tool_arguments:
month: 1
year: 2025
- num_tool_calls: 0
answer: ["1000", "$1,000", "1,000"]
- num_tool_calls: 1
tool_name: getMonthlyExpenseSummary
tool_arguments:
month: 2
year: 2024
- num_tool_calls: 0
answer: ["yes"]

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import copy
import json
import re
from typing import Any
@ -243,43 +244,294 @@ def test_chat_streaming_tool_calling(request, openai_client, model, provider, ve
stream=True,
)
# Accumulate partial tool_calls here
tool_calls_buffer = {}
current_id = None
# Process streaming chunks
for chunk in stream:
choice = chunk.choices[0]
delta = choice.delta
if delta.tool_calls is None:
continue
for tool_call_delta in delta.tool_calls:
if tool_call_delta.id:
current_id = tool_call_delta.id
call_id = current_id
func_delta = tool_call_delta.function
if call_id not in tool_calls_buffer:
tool_calls_buffer[call_id] = {
"id": call_id,
"type": tool_call_delta.type,
"name": func_delta.name,
"arguments": "",
}
if func_delta.arguments:
tool_calls_buffer[call_id]["arguments"] += func_delta.arguments
_, tool_calls_buffer = _accumulate_streaming_tool_calls(stream)
assert len(tool_calls_buffer) == 1
for call in tool_calls_buffer.values():
for call in tool_calls_buffer:
assert len(call["id"]) > 0
assert call["name"] == "get_weather"
function = call["function"]
assert function["name"] == "get_weather"
args_dict = json.loads(call["arguments"])
args_dict = json.loads(function["arguments"])
assert "san francisco" in args_dict["location"].lower()
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
ids=case_id_generator,
)
def test_chat_non_streaming_tool_choice_required(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
response = openai_client.chat.completions.create(
model=model,
messages=case["input"]["messages"],
tools=case["input"]["tools"],
tool_choice="required", # Force tool call
stream=False,
)
print(response)
assert response.choices[0].message.role == "assistant"
assert len(response.choices[0].message.tool_calls) > 0, "Expected tool call when tool_choice='required'"
expected_tool_name = case["input"]["tools"][0]["function"]["name"]
assert response.choices[0].message.tool_calls[0].function.name == expected_tool_name
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
ids=case_id_generator,
)
def test_chat_streaming_tool_choice_required(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
stream = openai_client.chat.completions.create(
model=model,
messages=case["input"]["messages"],
tools=case["input"]["tools"],
tool_choice="required", # Force tool call
stream=True,
)
_, tool_calls_buffer = _accumulate_streaming_tool_calls(stream)
assert len(tool_calls_buffer) > 0, "Expected tool call when tool_choice='required'"
expected_tool_name = case["input"]["tools"][0]["function"]["name"]
assert any(call["function"]["name"] == expected_tool_name for call in tool_calls_buffer), (
f"Expected tool call '{expected_tool_name}' not found in stream"
)
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
ids=case_id_generator,
)
def test_chat_non_streaming_tool_choice_none(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
response = openai_client.chat.completions.create(
model=model,
messages=case["input"]["messages"],
tools=case["input"]["tools"],
tool_choice="none",
stream=False,
)
assert response.choices[0].message.role == "assistant"
assert response.choices[0].message.tool_calls is None, "Expected no tool calls when tool_choice='none'"
assert response.choices[0].message.content is not None, "Expected content when tool_choice='none'"
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
ids=case_id_generator,
)
def test_chat_streaming_tool_choice_none(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
stream = openai_client.chat.completions.create(
model=model,
messages=case["input"]["messages"],
tools=case["input"]["tools"],
tool_choice="none",
stream=True,
)
content = ""
for chunk in stream:
delta = chunk.choices[0].delta
if delta.content:
content += delta.content
assert not delta.tool_calls, "Expected no tool call chunks when tool_choice='none'"
assert len(content) > 0, "Expected content when tool_choice='none'"
@pytest.mark.parametrize(
"case",
chat_completion_test_cases.get("test_chat_multi_turn_tool_calling", {}).get("test_params", {}).get("case", []),
ids=case_id_generator,
)
def test_chat_non_streaming_multi_turn_tool_calling(request, openai_client, model, provider, verification_config, case):
"""
Test cases for multi-turn tool calling.
Tool calls are asserted.
Tool responses are provided in the test case.
Final response is asserted.
"""
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
# Create a copy of the messages list to avoid modifying the original
messages = []
tools = case["input"]["tools"]
# Use deepcopy to prevent modification across runs/parametrization
expected_results = copy.deepcopy(case["expected"])
tool_responses = copy.deepcopy(case.get("tool_responses", []))
input_messages_turns = copy.deepcopy(case["input"]["messages"])
# keep going until either
# 1. we have messages to test in multi-turn
# 2. no messages but last message is tool response
while len(input_messages_turns) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
# do not take new messages if last message is tool response
if len(messages) == 0 or messages[-1]["role"] != "tool":
new_messages = input_messages_turns.pop(0)
# Ensure new_messages is a list of message objects
if isinstance(new_messages, list):
messages.extend(new_messages)
else:
# If it's a single message object, add it directly
messages.append(new_messages)
# --- API Call ---
response = openai_client.chat.completions.create(
model=model,
messages=messages,
tools=tools,
stream=False,
)
# --- Process Response ---
assistant_message = response.choices[0].message
messages.append(assistant_message.model_dump(exclude_unset=True))
assert assistant_message.role == "assistant"
# Get the expected result data
expected = expected_results.pop(0)
num_tool_calls = expected["num_tool_calls"]
# --- Assertions based on expected result ---
assert len(assistant_message.tool_calls or []) == num_tool_calls, (
f"Expected {num_tool_calls} tool calls, but got {len(assistant_message.tool_calls or [])}"
)
if num_tool_calls > 0:
tool_call = assistant_message.tool_calls[0]
assert tool_call.function.name == expected["tool_name"], (
f"Expected tool '{expected['tool_name']}', got '{tool_call.function.name}'"
)
# Parse the JSON string arguments before comparing
actual_arguments = json.loads(tool_call.function.arguments)
assert actual_arguments == expected["tool_arguments"], (
f"Expected arguments '{expected['tool_arguments']}', got '{actual_arguments}'"
)
# Prepare and append the tool response for the next turn
tool_response = tool_responses.pop(0)
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": tool_response["response"],
}
)
else:
assert assistant_message.content is not None, "Expected content, but none received."
expected_answers = expected["answer"] # This is now a list
content_lower = assistant_message.content.lower()
assert any(ans.lower() in content_lower for ans in expected_answers), (
f"Expected one of {expected_answers} in content, but got: '{assistant_message.content}'"
)
@pytest.mark.parametrize(
"case",
chat_completion_test_cases.get("test_chat_multi_turn_tool_calling", {}).get("test_params", {}).get("case", []),
ids=case_id_generator,
)
def test_chat_streaming_multi_turn_tool_calling(request, openai_client, model, provider, verification_config, case):
""" """
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
messages = []
tools = case["input"]["tools"]
expected_results = copy.deepcopy(case["expected"])
tool_responses = copy.deepcopy(case.get("tool_responses", []))
input_messages_turns = copy.deepcopy(case["input"]["messages"])
while len(input_messages_turns) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
if len(messages) == 0 or messages[-1]["role"] != "tool":
new_messages = input_messages_turns.pop(0)
if isinstance(new_messages, list):
messages.extend(new_messages)
else:
messages.append(new_messages)
# --- API Call (Streaming) ---
stream = openai_client.chat.completions.create(
model=model,
messages=messages,
tools=tools,
stream=True,
)
# --- Process Stream ---
accumulated_content, accumulated_tool_calls = _accumulate_streaming_tool_calls(stream)
# --- Construct Assistant Message for History ---
assistant_message_dict = {"role": "assistant"}
if accumulated_content:
assistant_message_dict["content"] = accumulated_content
if accumulated_tool_calls:
assistant_message_dict["tool_calls"] = accumulated_tool_calls
messages.append(assistant_message_dict)
# --- Assertions ---
expected = expected_results.pop(0)
num_tool_calls = expected["num_tool_calls"]
assert len(accumulated_tool_calls or []) == num_tool_calls, (
f"Expected {num_tool_calls} tool calls, but got {len(accumulated_tool_calls or [])}"
)
if num_tool_calls > 0:
# Use the first accumulated tool call for assertion
tool_call = accumulated_tool_calls[0]
assert tool_call["function"]["name"] == expected["tool_name"], (
f"Expected tool '{expected['tool_name']}', got '{tool_call['function']['name']}'"
)
# Parse the accumulated arguments string for comparison
actual_arguments = json.loads(tool_call["function"]["arguments"])
assert actual_arguments == expected["tool_arguments"], (
f"Expected arguments '{expected['tool_arguments']}', got '{actual_arguments}'"
)
# Prepare and append the tool response for the next turn
tool_response = tool_responses.pop(0)
messages.append(
{
"role": "tool",
"tool_call_id": tool_call["id"],
"content": tool_response["response"],
}
)
else:
assert accumulated_content is not None and accumulated_content != "", "Expected content, but none received."
expected_answers = expected["answer"]
content_lower = accumulated_content.lower()
assert any(ans.lower() in content_lower for ans in expected_answers), (
f"Expected one of {expected_answers} in content, but got: '{accumulated_content}'"
)
# --- Helper functions (structured output validation) ---
@ -324,3 +576,47 @@ def validate_structured_output(maybe_json_content: str, schema_name: str) -> Non
assert len(structured_output.participants) == 2
elif schema_name == "valid_math_reasoning":
assert len(structured_output.final_answer) > 0
def _accumulate_streaming_tool_calls(stream):
"""Accumulates tool calls and content from a streaming ChatCompletion response."""
tool_calls_buffer = {}
current_id = None
full_content = "" # Initialize content accumulator
# Process streaming chunks
for chunk in stream:
choice = chunk.choices[0]
delta = choice.delta
# Accumulate content
if delta.content:
full_content += delta.content
if delta.tool_calls is None:
continue
for tool_call_delta in delta.tool_calls:
if tool_call_delta.id:
current_id = tool_call_delta.id
call_id = current_id
# Skip if no ID seen yet for this tool call delta
if not call_id:
continue
func_delta = tool_call_delta.function
if call_id not in tool_calls_buffer:
tool_calls_buffer[call_id] = {
"id": call_id,
"type": "function", # Assume function type
"function": {"name": None, "arguments": ""}, # Nested structure
}
# Accumulate name and arguments into the nested function dict
if func_delta:
if func_delta.name:
tool_calls_buffer[call_id]["function"]["name"] = func_delta.name
if func_delta.arguments:
tool_calls_buffer[call_id]["function"]["arguments"] += func_delta.arguments
# Return content and tool calls as a list
return full_content, list(tool_calls_buffer.values())

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load diff

File diff suppressed because one or more lines are too long