mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 01:03:59 +00:00
Merge branch 'main' into register_custom_model
This commit is contained in:
commit
afb792b9c1
69 changed files with 8875 additions and 890 deletions
36
.github/workflows/integration-tests.yml
vendored
36
.github/workflows/integration-tests.yml
vendored
|
@ -34,22 +34,20 @@ jobs:
|
|||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
|
||||
uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install Ollama
|
||||
- name: Install and start Ollama
|
||||
run: |
|
||||
# the ollama installer also starts the ollama service
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
|
||||
- name: Pull Ollama image
|
||||
run: |
|
||||
# TODO: cache the model. OLLAMA_MODELS defaults to ~ollama/.ollama/models.
|
||||
ollama pull llama3.2:3b-instruct-fp16
|
||||
|
||||
- name: Start Ollama in background
|
||||
run: |
|
||||
nohup ollama run llama3.2:3b-instruct-fp16 > ollama.log 2>&1 &
|
||||
|
||||
- name: Set Up Environment and Install Dependencies
|
||||
run: |
|
||||
uv sync --extra dev --extra test
|
||||
|
@ -61,21 +59,6 @@ jobs:
|
|||
uv pip install -e .
|
||||
llama stack build --template ollama --image-type venv
|
||||
|
||||
- name: Wait for Ollama to start
|
||||
run: |
|
||||
echo "Waiting for Ollama..."
|
||||
for i in {1..30}; do
|
||||
if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
|
||||
echo "Ollama is running!"
|
||||
exit 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
echo "Ollama failed to start"
|
||||
ollama ps
|
||||
ollama.log
|
||||
exit 1
|
||||
|
||||
- name: Start Llama Stack server in background
|
||||
if: matrix.client-type == 'http'
|
||||
env:
|
||||
|
@ -99,6 +82,17 @@ jobs:
|
|||
cat server.log
|
||||
exit 1
|
||||
|
||||
- name: Verify Ollama status is OK
|
||||
if: matrix.client-type == 'http'
|
||||
run: |
|
||||
echo "Verifying Ollama status..."
|
||||
ollama_status=$(curl -s -L http://127.0.0.1:8321/v1/providers/ollama|jq --raw-output .health.status)
|
||||
echo "Ollama status: $ollama_status"
|
||||
if [ "$ollama_status" != "OK" ]; then
|
||||
echo "Ollama health check failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Run Integration Tests
|
||||
env:
|
||||
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
|
||||
|
|
9
.github/workflows/pre-commit.yml
vendored
9
.github/workflows/pre-commit.yml
vendored
|
@ -31,3 +31,12 @@ jobs:
|
|||
- name: Verify if there are any diff files after pre-commit
|
||||
run: |
|
||||
git diff --exit-code || (echo "There are uncommitted changes, run pre-commit locally and commit again" && exit 1)
|
||||
|
||||
- name: Verify if there are any new files after pre-commit
|
||||
run: |
|
||||
unstaged_files=$(git ls-files --others --exclude-standard)
|
||||
if [ -n "$unstaged_files" ]; then
|
||||
echo "There are uncommitted new files, run pre-commit locally and commit again"
|
||||
echo "$unstaged_files"
|
||||
exit 1
|
||||
fi
|
||||
|
|
28
.github/workflows/providers-build.yml
vendored
28
.github/workflows/providers-build.yml
vendored
|
@ -56,7 +56,7 @@ jobs:
|
|||
python-version: '3.10'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
|
||||
uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
|
@ -81,3 +81,29 @@ jobs:
|
|||
run: |
|
||||
source test/bin/activate
|
||||
uv pip list
|
||||
|
||||
build-single-provider:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install LlamaStack
|
||||
run: |
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
uv pip install -e .
|
||||
|
||||
- name: Build a single provider
|
||||
run: |
|
||||
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --image-type venv --image-name test --providers inference=remote::ollama
|
||||
|
|
2
.github/workflows/unit-tests.yml
vendored
2
.github/workflows/unit-tests.yml
vendored
|
@ -38,7 +38,7 @@ jobs:
|
|||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
|
||||
- uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
enable-cache: false
|
||||
|
|
2
.github/workflows/update-readthedocs.yml
vendored
2
.github/workflows/update-readthedocs.yml
vendored
|
@ -41,7 +41,7 @@ jobs:
|
|||
python-version: '3.11'
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
|
||||
uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
|
||||
|
||||
- name: Sync with uv
|
||||
run: uv sync --extra docs
|
||||
|
|
10
README.md
10
README.md
|
@ -9,15 +9,16 @@
|
|||
|
||||
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
|
||||
|
||||
|
||||
### ✨🎉 Llama 4 Support 🎉✨
|
||||
We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.
|
||||
|
||||
You can now run Llama 4 models on Llama Stack.
|
||||
<details>
|
||||
|
||||
<summary>👋 Click here to see how to run Llama 4 models on Llama Stack </summary>
|
||||
|
||||
\
|
||||
*Note you need 8xH100 GPU-host to run these models*
|
||||
|
||||
|
||||
```bash
|
||||
pip install -U llama_stack
|
||||
|
||||
|
@ -67,6 +68,9 @@ print(f"Assistant> {response.completion_message.content}")
|
|||
As more providers start supporting Llama 4, you can use them in Llama Stack as well. We are adding to the list. Stay tuned!
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
### Overview
|
||||
|
||||
Llama Stack standardizes the core building blocks that simplify AI application development. It codifies best practices across the Llama ecosystem. More specifically, it provides
|
||||
|
|
446
docs/_static/llama-stack-spec.html
vendored
446
docs/_static/llama-stack-spec.html
vendored
|
@ -3096,11 +3096,18 @@
|
|||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"description": "Response from an OpenAI-compatible chat completion request. **OR** Chunk from a streaming response to an OpenAI-compatible chat completion request.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletion"
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletion"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionChunk"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7889,7 +7896,13 @@
|
|||
"type": "object",
|
||||
"properties": {
|
||||
"status": {
|
||||
"type": "string"
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"OK",
|
||||
"Error",
|
||||
"Not Implemented"
|
||||
],
|
||||
"title": "HealthStatus"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -8084,6 +8097,31 @@
|
|||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"health": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "null"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -8091,7 +8129,8 @@
|
|||
"api",
|
||||
"provider_id",
|
||||
"provider_type",
|
||||
"config"
|
||||
"config",
|
||||
"health"
|
||||
],
|
||||
"title": "ProviderInfo"
|
||||
},
|
||||
|
@ -8825,7 +8864,17 @@
|
|||
"description": "Must be \"assistant\" to identify this as the model's response"
|
||||
},
|
||||
"content": {
|
||||
"$ref": "#/components/schemas/InterleavedContent",
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
|
||||
}
|
||||
}
|
||||
],
|
||||
"description": "The content of the model's response"
|
||||
},
|
||||
"name": {
|
||||
|
@ -8835,9 +8884,9 @@
|
|||
"tool_calls": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/ToolCall"
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionToolCall"
|
||||
},
|
||||
"description": "List of tool calls. Each tool call is a ToolCall object."
|
||||
"description": "List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -8848,6 +8897,98 @@
|
|||
"title": "OpenAIAssistantMessageParam",
|
||||
"description": "A message containing the model's (assistant) response in an OpenAI-compatible chat completion request."
|
||||
},
|
||||
"OpenAIChatCompletionContentPartImageParam": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "image_url",
|
||||
"default": "image_url"
|
||||
},
|
||||
"image_url": {
|
||||
"$ref": "#/components/schemas/OpenAIImageURL"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"image_url"
|
||||
],
|
||||
"title": "OpenAIChatCompletionContentPartImageParam"
|
||||
},
|
||||
"OpenAIChatCompletionContentPartParam": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"text": "#/components/schemas/OpenAIChatCompletionContentPartTextParam",
|
||||
"image_url": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
|
||||
}
|
||||
}
|
||||
},
|
||||
"OpenAIChatCompletionContentPartTextParam": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "text",
|
||||
"default": "text"
|
||||
},
|
||||
"text": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"text"
|
||||
],
|
||||
"title": "OpenAIChatCompletionContentPartTextParam"
|
||||
},
|
||||
"OpenAIChatCompletionToolCall": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"index": {
|
||||
"type": "integer"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "function",
|
||||
"default": "function"
|
||||
},
|
||||
"function": {
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionToolCallFunction"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type"
|
||||
],
|
||||
"title": "OpenAIChatCompletionToolCall"
|
||||
},
|
||||
"OpenAIChatCompletionToolCallFunction": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"arguments": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"title": "OpenAIChatCompletionToolCallFunction"
|
||||
},
|
||||
"OpenAIDeveloperMessageParam": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -8858,7 +8999,17 @@
|
|||
"description": "Must be \"developer\" to identify this as a developer message"
|
||||
},
|
||||
"content": {
|
||||
"$ref": "#/components/schemas/InterleavedContent",
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
|
||||
}
|
||||
}
|
||||
],
|
||||
"description": "The content of the developer message"
|
||||
},
|
||||
"name": {
|
||||
|
@ -8874,6 +9025,66 @@
|
|||
"title": "OpenAIDeveloperMessageParam",
|
||||
"description": "A message from the developer in an OpenAI-compatible chat completion request."
|
||||
},
|
||||
"OpenAIImageURL": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string"
|
||||
},
|
||||
"detail": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"url"
|
||||
],
|
||||
"title": "OpenAIImageURL"
|
||||
},
|
||||
"OpenAIJSONSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"strict": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "null"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"name"
|
||||
],
|
||||
"title": "OpenAIJSONSchema"
|
||||
},
|
||||
"OpenAIMessageParam": {
|
||||
"oneOf": [
|
||||
{
|
||||
|
@ -8903,6 +9114,76 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"OpenAIResponseFormatJSONObject": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "json_object",
|
||||
"default": "json_object"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type"
|
||||
],
|
||||
"title": "OpenAIResponseFormatJSONObject"
|
||||
},
|
||||
"OpenAIResponseFormatJSONSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "json_schema",
|
||||
"default": "json_schema"
|
||||
},
|
||||
"json_schema": {
|
||||
"$ref": "#/components/schemas/OpenAIJSONSchema"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"json_schema"
|
||||
],
|
||||
"title": "OpenAIResponseFormatJSONSchema"
|
||||
},
|
||||
"OpenAIResponseFormatParam": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseFormatText"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseFormatJSONSchema"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseFormatJSONObject"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"text": "#/components/schemas/OpenAIResponseFormatText",
|
||||
"json_schema": "#/components/schemas/OpenAIResponseFormatJSONSchema",
|
||||
"json_object": "#/components/schemas/OpenAIResponseFormatJSONObject"
|
||||
}
|
||||
}
|
||||
},
|
||||
"OpenAIResponseFormatText": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "text",
|
||||
"default": "text"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type"
|
||||
],
|
||||
"title": "OpenAIResponseFormatText"
|
||||
},
|
||||
"OpenAISystemMessageParam": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -8913,7 +9194,17 @@
|
|||
"description": "Must be \"system\" to identify this as a system message"
|
||||
},
|
||||
"content": {
|
||||
"$ref": "#/components/schemas/InterleavedContent",
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
|
||||
}
|
||||
}
|
||||
],
|
||||
"description": "The content of the \"system prompt\". If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other system messages (for example, for formatting tool definitions)."
|
||||
},
|
||||
"name": {
|
||||
|
@ -8943,7 +9234,17 @@
|
|||
"description": "Unique identifier for the tool call this response is for"
|
||||
},
|
||||
"content": {
|
||||
"$ref": "#/components/schemas/InterleavedContent",
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
|
||||
}
|
||||
}
|
||||
],
|
||||
"description": "The response content from the tool"
|
||||
}
|
||||
},
|
||||
|
@ -8966,7 +9267,17 @@
|
|||
"description": "Must be \"user\" to identify this as a user message"
|
||||
},
|
||||
"content": {
|
||||
"$ref": "#/components/schemas/InterleavedContent",
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
|
||||
}
|
||||
}
|
||||
],
|
||||
"description": "The content of the message, which can include text and other media"
|
||||
},
|
||||
"name": {
|
||||
|
@ -9094,10 +9405,7 @@
|
|||
"description": "(Optional) The penalty for repeated tokens"
|
||||
},
|
||||
"response_format": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
},
|
||||
"$ref": "#/components/schemas/OpenAIResponseFormatParam",
|
||||
"description": "(Optional) The response format to use"
|
||||
},
|
||||
"seed": {
|
||||
|
@ -9274,6 +9582,46 @@
|
|||
"title": "OpenAIChatCompletion",
|
||||
"description": "Response from an OpenAI-compatible chat completion request."
|
||||
},
|
||||
"OpenAIChatCompletionChunk": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": "The ID of the chat completion"
|
||||
},
|
||||
"choices": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/OpenAIChunkChoice"
|
||||
},
|
||||
"description": "List of choices"
|
||||
},
|
||||
"object": {
|
||||
"type": "string",
|
||||
"const": "chat.completion.chunk",
|
||||
"default": "chat.completion.chunk",
|
||||
"description": "The object type, which will be \"chat.completion.chunk\""
|
||||
},
|
||||
"created": {
|
||||
"type": "integer",
|
||||
"description": "The Unix timestamp in seconds when the chat completion was created"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "The model that was used to generate the chat completion"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"id",
|
||||
"choices",
|
||||
"object",
|
||||
"created",
|
||||
"model"
|
||||
],
|
||||
"title": "OpenAIChatCompletionChunk",
|
||||
"description": "Chunk from a streaming response to an OpenAI-compatible chat completion request."
|
||||
},
|
||||
"OpenAIChoice": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -9286,10 +9634,12 @@
|
|||
"description": "The reason the model stopped generating"
|
||||
},
|
||||
"index": {
|
||||
"type": "integer"
|
||||
"type": "integer",
|
||||
"description": "The index of the choice"
|
||||
},
|
||||
"logprobs": {
|
||||
"$ref": "#/components/schemas/OpenAIChoiceLogprobs"
|
||||
"$ref": "#/components/schemas/OpenAIChoiceLogprobs",
|
||||
"description": "(Optional) The log probabilities for the tokens in the message"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -9301,6 +9651,33 @@
|
|||
"title": "OpenAIChoice",
|
||||
"description": "A choice from an OpenAI-compatible chat completion response."
|
||||
},
|
||||
"OpenAIChoiceDelta": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "(Optional) The content of the delta"
|
||||
},
|
||||
"refusal": {
|
||||
"type": "string",
|
||||
"description": "(Optional) The refusal of the delta"
|
||||
},
|
||||
"role": {
|
||||
"type": "string",
|
||||
"description": "(Optional) The role of the delta"
|
||||
},
|
||||
"tool_calls": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionToolCall"
|
||||
},
|
||||
"description": "(Optional) The tool calls of the delta"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"title": "OpenAIChoiceDelta",
|
||||
"description": "A delta from an OpenAI-compatible chat completion streaming response."
|
||||
},
|
||||
"OpenAIChoiceLogprobs": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -9308,19 +9685,50 @@
|
|||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/OpenAITokenLogProb"
|
||||
}
|
||||
},
|
||||
"description": "(Optional) The log probabilities for the tokens in the message"
|
||||
},
|
||||
"refusal": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/OpenAITokenLogProb"
|
||||
}
|
||||
},
|
||||
"description": "(Optional) The log probabilities for the tokens in the message"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"title": "OpenAIChoiceLogprobs",
|
||||
"description": "The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response."
|
||||
},
|
||||
"OpenAIChunkChoice": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"delta": {
|
||||
"$ref": "#/components/schemas/OpenAIChoiceDelta",
|
||||
"description": "The delta from the chunk"
|
||||
},
|
||||
"finish_reason": {
|
||||
"type": "string",
|
||||
"description": "The reason the model stopped generating"
|
||||
},
|
||||
"index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the choice"
|
||||
},
|
||||
"logprobs": {
|
||||
"$ref": "#/components/schemas/OpenAIChoiceLogprobs",
|
||||
"description": "(Optional) The log probabilities for the tokens in the message"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"delta",
|
||||
"finish_reason",
|
||||
"index"
|
||||
],
|
||||
"title": "OpenAIChunkChoice",
|
||||
"description": "A chunk choice from an OpenAI-compatible chat completion streaming response."
|
||||
},
|
||||
"OpenAITokenLogProb": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
295
docs/_static/llama-stack-spec.yaml
vendored
295
docs/_static/llama-stack-spec.yaml
vendored
|
@ -2135,11 +2135,15 @@ paths:
|
|||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
description: >-
|
||||
Response from an OpenAI-compatible chat completion request. **OR** Chunk
|
||||
from a streaming response to an OpenAI-compatible chat completion request.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/OpenAIChatCompletion'
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIChatCompletion'
|
||||
- $ref: '#/components/schemas/OpenAIChatCompletionChunk'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
|
@ -5463,6 +5467,11 @@ components:
|
|||
properties:
|
||||
status:
|
||||
type: string
|
||||
enum:
|
||||
- OK
|
||||
- Error
|
||||
- Not Implemented
|
||||
title: HealthStatus
|
||||
additionalProperties: false
|
||||
required:
|
||||
- status
|
||||
|
@ -5574,12 +5583,23 @@ components:
|
|||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
health:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- api
|
||||
- provider_id
|
||||
- provider_type
|
||||
- config
|
||||
- health
|
||||
title: ProviderInfo
|
||||
InvokeToolRequest:
|
||||
type: object
|
||||
|
@ -6057,7 +6077,11 @@ components:
|
|||
description: >-
|
||||
Must be "assistant" to identify this as the model's response
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
|
||||
description: The content of the model's response
|
||||
name:
|
||||
type: string
|
||||
|
@ -6066,9 +6090,10 @@ components:
|
|||
tool_calls:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/ToolCall'
|
||||
$ref: '#/components/schemas/OpenAIChatCompletionToolCall'
|
||||
description: >-
|
||||
List of tool calls. Each tool call is a ToolCall object.
|
||||
List of tool calls. Each tool call is an OpenAIChatCompletionToolCall
|
||||
object.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- role
|
||||
|
@ -6077,6 +6102,70 @@ components:
|
|||
description: >-
|
||||
A message containing the model's (assistant) response in an OpenAI-compatible
|
||||
chat completion request.
|
||||
"OpenAIChatCompletionContentPartImageParam":
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: image_url
|
||||
default: image_url
|
||||
image_url:
|
||||
$ref: '#/components/schemas/OpenAIImageURL'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- image_url
|
||||
title: >-
|
||||
OpenAIChatCompletionContentPartImageParam
|
||||
OpenAIChatCompletionContentPartParam:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
|
||||
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
text: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
|
||||
image_url: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
|
||||
OpenAIChatCompletionContentPartTextParam:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: text
|
||||
default: text
|
||||
text:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- text
|
||||
title: OpenAIChatCompletionContentPartTextParam
|
||||
OpenAIChatCompletionToolCall:
|
||||
type: object
|
||||
properties:
|
||||
index:
|
||||
type: integer
|
||||
id:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
const: function
|
||||
default: function
|
||||
function:
|
||||
$ref: '#/components/schemas/OpenAIChatCompletionToolCallFunction'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
title: OpenAIChatCompletionToolCall
|
||||
OpenAIChatCompletionToolCallFunction:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
arguments:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
title: OpenAIChatCompletionToolCallFunction
|
||||
OpenAIDeveloperMessageParam:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -6087,7 +6176,11 @@ components:
|
|||
description: >-
|
||||
Must be "developer" to identify this as a developer message
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
|
||||
description: The content of the developer message
|
||||
name:
|
||||
type: string
|
||||
|
@ -6100,6 +6193,40 @@ components:
|
|||
title: OpenAIDeveloperMessageParam
|
||||
description: >-
|
||||
A message from the developer in an OpenAI-compatible chat completion request.
|
||||
OpenAIImageURL:
|
||||
type: object
|
||||
properties:
|
||||
url:
|
||||
type: string
|
||||
detail:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- url
|
||||
title: OpenAIImageURL
|
||||
OpenAIJSONSchema:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description:
|
||||
type: string
|
||||
strict:
|
||||
type: boolean
|
||||
schema:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- name
|
||||
title: OpenAIJSONSchema
|
||||
OpenAIMessageParam:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIUserMessageParam'
|
||||
|
@ -6115,6 +6242,53 @@ components:
|
|||
assistant: '#/components/schemas/OpenAIAssistantMessageParam'
|
||||
tool: '#/components/schemas/OpenAIToolMessageParam'
|
||||
developer: '#/components/schemas/OpenAIDeveloperMessageParam'
|
||||
OpenAIResponseFormatJSONObject:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: json_object
|
||||
default: json_object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
title: OpenAIResponseFormatJSONObject
|
||||
OpenAIResponseFormatJSONSchema:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: json_schema
|
||||
default: json_schema
|
||||
json_schema:
|
||||
$ref: '#/components/schemas/OpenAIJSONSchema'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- json_schema
|
||||
title: OpenAIResponseFormatJSONSchema
|
||||
OpenAIResponseFormatParam:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseFormatText'
|
||||
- $ref: '#/components/schemas/OpenAIResponseFormatJSONSchema'
|
||||
- $ref: '#/components/schemas/OpenAIResponseFormatJSONObject'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
text: '#/components/schemas/OpenAIResponseFormatText'
|
||||
json_schema: '#/components/schemas/OpenAIResponseFormatJSONSchema'
|
||||
json_object: '#/components/schemas/OpenAIResponseFormatJSONObject'
|
||||
OpenAIResponseFormatText:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: text
|
||||
default: text
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
title: OpenAIResponseFormatText
|
||||
OpenAISystemMessageParam:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -6125,7 +6299,11 @@ components:
|
|||
description: >-
|
||||
Must be "system" to identify this as a system message
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
|
||||
description: >-
|
||||
The content of the "system prompt". If multiple system messages are provided,
|
||||
they are concatenated. The underlying Llama Stack code may also add other
|
||||
|
@ -6155,7 +6333,11 @@ components:
|
|||
description: >-
|
||||
Unique identifier for the tool call this response is for
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
|
||||
description: The response content from the tool
|
||||
additionalProperties: false
|
||||
required:
|
||||
|
@ -6176,7 +6358,11 @@ components:
|
|||
description: >-
|
||||
Must be "user" to identify this as a user message
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
|
||||
description: >-
|
||||
The content of the message, which can include text and other media
|
||||
name:
|
||||
|
@ -6262,9 +6448,7 @@ components:
|
|||
description: >-
|
||||
(Optional) The penalty for repeated tokens
|
||||
response_format:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
$ref: '#/components/schemas/OpenAIResponseFormatParam'
|
||||
description: (Optional) The response format to use
|
||||
seed:
|
||||
type: integer
|
||||
|
@ -6370,6 +6554,41 @@ components:
|
|||
title: OpenAIChatCompletion
|
||||
description: >-
|
||||
Response from an OpenAI-compatible chat completion request.
|
||||
OpenAIChatCompletionChunk:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
description: The ID of the chat completion
|
||||
choices:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAIChunkChoice'
|
||||
description: List of choices
|
||||
object:
|
||||
type: string
|
||||
const: chat.completion.chunk
|
||||
default: chat.completion.chunk
|
||||
description: >-
|
||||
The object type, which will be "chat.completion.chunk"
|
||||
created:
|
||||
type: integer
|
||||
description: >-
|
||||
The Unix timestamp in seconds when the chat completion was created
|
||||
model:
|
||||
type: string
|
||||
description: >-
|
||||
The model that was used to generate the chat completion
|
||||
additionalProperties: false
|
||||
required:
|
||||
- id
|
||||
- choices
|
||||
- object
|
||||
- created
|
||||
- model
|
||||
title: OpenAIChatCompletionChunk
|
||||
description: >-
|
||||
Chunk from a streaming response to an OpenAI-compatible chat completion request.
|
||||
OpenAIChoice:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -6381,8 +6600,11 @@ components:
|
|||
description: The reason the model stopped generating
|
||||
index:
|
||||
type: integer
|
||||
description: The index of the choice
|
||||
logprobs:
|
||||
$ref: '#/components/schemas/OpenAIChoiceLogprobs'
|
||||
description: >-
|
||||
(Optional) The log probabilities for the tokens in the message
|
||||
additionalProperties: false
|
||||
required:
|
||||
- message
|
||||
|
@ -6391,6 +6613,27 @@ components:
|
|||
title: OpenAIChoice
|
||||
description: >-
|
||||
A choice from an OpenAI-compatible chat completion response.
|
||||
OpenAIChoiceDelta:
|
||||
type: object
|
||||
properties:
|
||||
content:
|
||||
type: string
|
||||
description: (Optional) The content of the delta
|
||||
refusal:
|
||||
type: string
|
||||
description: (Optional) The refusal of the delta
|
||||
role:
|
||||
type: string
|
||||
description: (Optional) The role of the delta
|
||||
tool_calls:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAIChatCompletionToolCall'
|
||||
description: (Optional) The tool calls of the delta
|
||||
additionalProperties: false
|
||||
title: OpenAIChoiceDelta
|
||||
description: >-
|
||||
A delta from an OpenAI-compatible chat completion streaming response.
|
||||
OpenAIChoiceLogprobs:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -6398,15 +6641,43 @@ components:
|
|||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAITokenLogProb'
|
||||
description: >-
|
||||
(Optional) The log probabilities for the tokens in the message
|
||||
refusal:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAITokenLogProb'
|
||||
description: >-
|
||||
(Optional) The log probabilities for the tokens in the message
|
||||
additionalProperties: false
|
||||
title: OpenAIChoiceLogprobs
|
||||
description: >-
|
||||
The log probabilities for the tokens in the message from an OpenAI-compatible
|
||||
chat completion response.
|
||||
OpenAIChunkChoice:
|
||||
type: object
|
||||
properties:
|
||||
delta:
|
||||
$ref: '#/components/schemas/OpenAIChoiceDelta'
|
||||
description: The delta from the chunk
|
||||
finish_reason:
|
||||
type: string
|
||||
description: The reason the model stopped generating
|
||||
index:
|
||||
type: integer
|
||||
description: The index of the choice
|
||||
logprobs:
|
||||
$ref: '#/components/schemas/OpenAIChoiceLogprobs'
|
||||
description: >-
|
||||
(Optional) The log probabilities for the tokens in the message
|
||||
additionalProperties: false
|
||||
required:
|
||||
- delta
|
||||
- finish_reason
|
||||
- index
|
||||
title: OpenAIChunkChoice
|
||||
description: >-
|
||||
A chunk choice from an OpenAI-compatible chat completion streaming response.
|
||||
OpenAITokenLogProb:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
@ -24,7 +24,7 @@ The key files in the app are `ExampleLlamaStackLocalInference.kt`, `ExampleLlama
|
|||
Add the following dependency in your `build.gradle.kts` file:
|
||||
```
|
||||
dependencies {
|
||||
implementation("com.llama.llamastack:llama-stack-client-kotlin:0.1.4.2")
|
||||
implementation("com.llama.llamastack:llama-stack-client-kotlin:0.2.2")
|
||||
}
|
||||
```
|
||||
This will download jar files in your gradle cache in a directory like `~/.gradle/caches/modules-2/files-2.1/com.llama.llamastack/`
|
||||
|
@ -37,11 +37,7 @@ For local inferencing, it is required to include the ExecuTorch library into you
|
|||
|
||||
Include the ExecuTorch library by:
|
||||
1. Download the `download-prebuilt-et-lib.sh` script file from the [llama-stack-client-kotlin-client-local](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/llama-stack-client-kotlin-client-local/download-prebuilt-et-lib.sh) directory to your local machine.
|
||||
2. Move the script to the top level of your Android app where the app directory resides:
|
||||
<p align="center">
|
||||
<img src="https://github.com/meta-llama/llama-stack-client-kotlin/blob/latest-release/doc/img/example_android_app_directory.png" style="width:300px">
|
||||
</p>
|
||||
|
||||
2. Move the script to the top level of your Android app where the `app` directory resides.
|
||||
3. Run `sh download-prebuilt-et-lib.sh` to create an `app/libs` directory and download the `executorch.aar` in that path. This generates an ExecuTorch library for the XNNPACK delegate.
|
||||
4. Add the `executorch.aar` dependency in your `build.gradle.kts` file:
|
||||
```
|
||||
|
@ -52,6 +48,8 @@ dependencies {
|
|||
}
|
||||
```
|
||||
|
||||
See other dependencies for the local RAG in Android app [README](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app#quick-start).
|
||||
|
||||
## Llama Stack APIs in Your Android App
|
||||
Breaking down the demo app, this section will show the core pieces that are used to initialize and run inference with Llama Stack using the Kotlin library.
|
||||
|
||||
|
@ -60,7 +58,7 @@ Start a Llama Stack server on localhost. Here is an example of how you can do th
|
|||
```
|
||||
conda create -n stack-fireworks python=3.10
|
||||
conda activate stack-fireworks
|
||||
pip install --no-cache llama-stack==0.1.4
|
||||
pip install --no-cache llama-stack==0.2.2
|
||||
llama stack build --template fireworks --image-type conda
|
||||
export FIREWORKS_API_KEY=<SOME_KEY>
|
||||
llama stack run fireworks --port 5050
|
||||
|
|
|
@ -43,7 +43,9 @@ The following models are available by default:
|
|||
- `groq/llama-3.3-70b-versatile (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||
- `groq/llama-3.2-3b-preview (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||
- `groq/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
|
||||
- `groq/meta-llama/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
|
||||
- `groq/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
|
||||
- `groq/meta-llama/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
|
||||
|
||||
|
||||
### Prerequisite: API Keys
|
||||
|
|
|
@ -41,7 +41,7 @@ The following environment variables can be configured:
|
|||
|
||||
## Setting up vLLM server
|
||||
|
||||
In the following sections, we'll use either AMD and NVIDIA GPUs to serve as hardware accelerators for the vLLM
|
||||
In the following sections, we'll use AMD, NVIDIA or Intel GPUs to serve as hardware accelerators for the vLLM
|
||||
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
|
||||
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
|
||||
that we only use GPUs here for demonstration purposes.
|
||||
|
@ -162,6 +162,55 @@ docker run \
|
|||
--port $SAFETY_PORT
|
||||
```
|
||||
|
||||
### Setting up vLLM server on Intel GPU
|
||||
|
||||
Refer to [vLLM Documentation for XPU](https://docs.vllm.ai/en/v0.8.2/getting_started/installation/gpu.html?device=xpu) to get a vLLM endpoint. In addition to vLLM side setup which guides towards installing vLLM from sources orself-building vLLM Docker container, Intel provides prebuilt vLLM container to use on systems with Intel GPUs supported by PyTorch XPU backend:
|
||||
- [intel/vllm](https://hub.docker.com/r/intel/vllm)
|
||||
|
||||
Here is a sample script to start a vLLM server locally via Docker using Intel provided container:
|
||||
|
||||
```bash
|
||||
export INFERENCE_PORT=8000
|
||||
export INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct
|
||||
export ZE_AFFINITY_MASK=0
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
--device /dev/dri \
|
||||
-v /dev/dri/by-path:/dev/dri/by-path \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||
--env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
|
||||
-p $INFERENCE_PORT:$INFERENCE_PORT \
|
||||
--ipc=host \
|
||||
intel/vllm:xpu \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--model $INFERENCE_MODEL \
|
||||
--port $INFERENCE_PORT
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
|
||||
|
||||
```bash
|
||||
export SAFETY_PORT=8081
|
||||
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
export ZE_AFFINITY_MASK=1
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
--device /dev/dri \
|
||||
-v /dev/dri/by-path:/dev/dri/by-path \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||
--env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
|
||||
-p $SAFETY_PORT:$SAFETY_PORT \
|
||||
--ipc=host \
|
||||
intel/vllm:xpu \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--model $SAFETY_MODEL \
|
||||
--port $SAFETY_PORT
|
||||
```
|
||||
|
||||
## Running Llama Stack
|
||||
|
||||
Now you are ready to run Llama Stack with vLLM as the inference provider. You can do this via Conda (build code) or Docker which has a pre-built image.
|
||||
|
|
|
@ -18,7 +18,7 @@ from typing import (
|
|||
)
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Annotated
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.models import Model
|
||||
|
@ -442,6 +442,37 @@ class EmbeddingsResponse(BaseModel):
|
|||
embeddings: List[List[float]]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionContentPartTextParam(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIImageURL(BaseModel):
|
||||
url: str
|
||||
detail: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionContentPartImageParam(BaseModel):
|
||||
type: Literal["image_url"] = "image_url"
|
||||
image_url: OpenAIImageURL
|
||||
|
||||
|
||||
OpenAIChatCompletionContentPartParam = Annotated[
|
||||
Union[
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
|
||||
|
||||
|
||||
OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIUserMessageParam(BaseModel):
|
||||
"""A message from the user in an OpenAI-compatible chat completion request.
|
||||
|
@ -452,7 +483,7 @@ class OpenAIUserMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["user"] = "user"
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
|
@ -466,10 +497,24 @@ class OpenAISystemMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["system"] = "system"
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionToolCallFunction(BaseModel):
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionToolCall(BaseModel):
|
||||
index: Optional[int] = None
|
||||
id: Optional[str] = None
|
||||
type: Literal["function"] = "function"
|
||||
function: Optional[OpenAIChatCompletionToolCallFunction] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIAssistantMessageParam(BaseModel):
|
||||
"""A message containing the model's (assistant) response in an OpenAI-compatible chat completion request.
|
||||
|
@ -477,13 +522,13 @@ class OpenAIAssistantMessageParam(BaseModel):
|
|||
:param role: Must be "assistant" to identify this as the model's response
|
||||
:param content: The content of the model's response
|
||||
:param name: (Optional) The name of the assistant message participant.
|
||||
:param tool_calls: List of tool calls. Each tool call is a ToolCall object.
|
||||
:param tool_calls: List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object.
|
||||
"""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
|
||||
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = Field(default_factory=list)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -497,7 +542,7 @@ class OpenAIToolMessageParam(BaseModel):
|
|||
|
||||
role: Literal["tool"] = "tool"
|
||||
tool_call_id: str
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -510,7 +555,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["developer"] = "developer"
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
|
@ -527,6 +572,46 @@ OpenAIMessageParam = Annotated[
|
|||
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseFormatText(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIJSONSchema(TypedDict, total=False):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
strict: Optional[bool] = None
|
||||
|
||||
# Pydantic BaseModel cannot be used with a schema param, since it already
|
||||
# has one. And, we don't want to alias here because then have to handle
|
||||
# that alias when converting to OpenAI params. So, to support schema,
|
||||
# we use a TypedDict.
|
||||
schema: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseFormatJSONSchema(BaseModel):
|
||||
type: Literal["json_schema"] = "json_schema"
|
||||
json_schema: OpenAIJSONSchema
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseFormatJSONObject(BaseModel):
|
||||
type: Literal["json_object"] = "json_object"
|
||||
|
||||
|
||||
OpenAIResponseFormatParam = Annotated[
|
||||
Union[
|
||||
OpenAIResponseFormatText,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAITopLogProb(BaseModel):
|
||||
"""The top log probability for a token from an OpenAI-compatible chat completion response.
|
||||
|
@ -561,22 +646,54 @@ class OpenAITokenLogProb(BaseModel):
|
|||
class OpenAIChoiceLogprobs(BaseModel):
|
||||
"""The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response.
|
||||
|
||||
:content: (Optional) The log probabilities for the tokens in the message
|
||||
:refusal: (Optional) The log probabilities for the tokens in the message
|
||||
:param content: (Optional) The log probabilities for the tokens in the message
|
||||
:param refusal: (Optional) The log probabilities for the tokens in the message
|
||||
"""
|
||||
|
||||
content: Optional[List[OpenAITokenLogProb]] = None
|
||||
refusal: Optional[List[OpenAITokenLogProb]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChoiceDelta(BaseModel):
|
||||
"""A delta from an OpenAI-compatible chat completion streaming response.
|
||||
|
||||
:param content: (Optional) The content of the delta
|
||||
:param refusal: (Optional) The refusal of the delta
|
||||
:param role: (Optional) The role of the delta
|
||||
:param tool_calls: (Optional) The tool calls of the delta
|
||||
"""
|
||||
|
||||
content: Optional[str] = None
|
||||
refusal: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChunkChoice(BaseModel):
|
||||
"""A chunk choice from an OpenAI-compatible chat completion streaming response.
|
||||
|
||||
:param delta: The delta from the chunk
|
||||
:param finish_reason: The reason the model stopped generating
|
||||
:param index: The index of the choice
|
||||
:param logprobs: (Optional) The log probabilities for the tokens in the message
|
||||
"""
|
||||
|
||||
delta: OpenAIChoiceDelta
|
||||
finish_reason: str
|
||||
index: int
|
||||
logprobs: Optional[OpenAIChoiceLogprobs] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChoice(BaseModel):
|
||||
"""A choice from an OpenAI-compatible chat completion response.
|
||||
|
||||
:param message: The message from the model
|
||||
:param finish_reason: The reason the model stopped generating
|
||||
:index: The index of the choice
|
||||
:logprobs: (Optional) The log probabilities for the tokens in the message
|
||||
:param index: The index of the choice
|
||||
:param logprobs: (Optional) The log probabilities for the tokens in the message
|
||||
"""
|
||||
|
||||
message: OpenAIMessageParam
|
||||
|
@ -603,6 +720,24 @@ class OpenAIChatCompletion(BaseModel):
|
|||
model: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionChunk(BaseModel):
|
||||
"""Chunk from a streaming response to an OpenAI-compatible chat completion request.
|
||||
|
||||
:param id: The ID of the chat completion
|
||||
:param choices: List of choices
|
||||
:param object: The object type, which will be "chat.completion.chunk"
|
||||
:param created: The Unix timestamp in seconds when the chat completion was created
|
||||
:param model: The model that was used to generate the chat completion
|
||||
"""
|
||||
|
||||
id: str
|
||||
choices: List[OpenAIChunkChoice]
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int
|
||||
model: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAICompletionLogprobs(BaseModel):
|
||||
"""The log probabilities for the tokens in the message from an OpenAI-compatible completion response.
|
||||
|
@ -872,7 +1007,7 @@ class Inference(Protocol):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
@ -883,7 +1018,7 @@ class Inference(Protocol):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import List, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
@ -20,8 +21,7 @@ class RouteInfo(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class HealthInfo(BaseModel):
|
||||
status: str
|
||||
# TODO: add a provider level status
|
||||
status: HealthStatus
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.datatypes import HealthResponse
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
@ -17,6 +18,7 @@ class ProviderInfo(BaseModel):
|
|||
provider_id: str
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
health: HealthResponse
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
|
|
|
@ -89,6 +89,43 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
color="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
elif args.providers:
|
||||
providers = dict()
|
||||
for api_provider in args.providers.split(","):
|
||||
if "=" not in api_provider:
|
||||
cprint(
|
||||
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
||||
color="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
api, provider = api_provider.split("=")
|
||||
providers_for_api = get_provider_registry().get(Api(api), None)
|
||||
if providers_for_api is None:
|
||||
cprint(
|
||||
f"{api} is not a valid API.",
|
||||
color="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
if provider in providers_for_api:
|
||||
providers.setdefault(api, []).append(provider)
|
||||
else:
|
||||
cprint(
|
||||
f"{provider} is not a valid provider for the {api} API.",
|
||||
color="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
distribution_spec = DistributionSpec(
|
||||
providers=providers,
|
||||
description=",".join(args.providers),
|
||||
)
|
||||
if not args.image_type:
|
||||
cprint(
|
||||
f"Please specify a image-type (container | conda | venv) for {args.template}",
|
||||
color="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
build_config = BuildConfig(image_type=args.image_type, distribution_spec=distribution_spec)
|
||||
elif not args.config and not args.template:
|
||||
name = prompt(
|
||||
"> Enter a name for your Llama Stack (e.g. my-local-stack): ",
|
||||
|
|
|
@ -75,6 +75,12 @@ the build. If not specified, currently active environment will be used if found.
|
|||
default=False,
|
||||
help="Run the stack after building using the same image type, name, and other applicable arguments",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--providers",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
|
||||
)
|
||||
|
||||
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
|
||||
# always keep implementation completely silo-ed away from CLI so CLI
|
||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.inspect import (
|
|||
)
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
|
||||
|
||||
class DistributionInspectConfig(BaseModel):
|
||||
|
@ -58,7 +59,7 @@ class DistributionInspectImpl(Inspect):
|
|||
return ListRoutesResponse(data=ret)
|
||||
|
||||
async def health(self) -> HealthInfo:
|
||||
return HealthInfo(status="OK")
|
||||
return HealthInfo(status=HealthStatus.OK)
|
||||
|
||||
async def version(self) -> VersionInfo:
|
||||
return VersionInfo(version=version("llama-stack"))
|
||||
|
|
|
@ -43,9 +43,9 @@ from llama_stack.distribution.server.endpoints import (
|
|||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
get_stack_run_config_from_template,
|
||||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
)
|
||||
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.distribution.utils.exec import in_notebook
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
|
|
|
@ -4,14 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
|
||||
|
||||
from .datatypes import StackRunConfig
|
||||
from .stack import redact_sensitive_fields
|
||||
from .utils.config import redact_sensitive_fields
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
@ -41,19 +44,24 @@ class ProviderImpl(Providers):
|
|||
async def list_providers(self) -> ListProvidersResponse:
|
||||
run_config = self.config.run_config
|
||||
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
|
||||
providers_health = await self.get_providers_health()
|
||||
ret = []
|
||||
for api, providers in safe_config.providers.items():
|
||||
ret.extend(
|
||||
[
|
||||
for p in providers:
|
||||
ret.append(
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
config=p.config,
|
||||
health=providers_health.get(api, {}).get(
|
||||
p.provider_id,
|
||||
HealthResponse(
|
||||
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
|
||||
),
|
||||
),
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
|
@ -64,3 +72,57 @@ class ProviderImpl(Providers):
|
|||
return p
|
||||
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]:
|
||||
"""Get health status for all providers.
|
||||
|
||||
Returns:
|
||||
Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
|
||||
Each API maps to a dictionary of provider IDs to their health responses.
|
||||
"""
|
||||
providers_health: Dict[str, Dict[str, HealthResponse]] = {}
|
||||
timeout = 1.0
|
||||
|
||||
async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
|
||||
# Skip special implementations (inspect/providers) that don't have provider specs
|
||||
if not hasattr(impl, "__provider_spec__"):
|
||||
return None
|
||||
api_name = impl.__provider_spec__.api.name
|
||||
if not hasattr(impl, "health"):
|
||||
return (
|
||||
api_name,
|
||||
HealthResponse(
|
||||
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
health = await asyncio.wait_for(impl.health(), timeout=timeout)
|
||||
return api_name, health
|
||||
except asyncio.TimeoutError:
|
||||
return (
|
||||
api_name,
|
||||
HealthResponse(
|
||||
status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds"
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
return (
|
||||
api_name,
|
||||
HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"),
|
||||
)
|
||||
|
||||
# Create tasks for all providers
|
||||
tasks = [check_provider_health(impl) for impl in self.deps.values()]
|
||||
|
||||
# Wait for all health checks to complete
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Organize results by API and provider ID
|
||||
for result in results:
|
||||
if result is None: # Skip special implementations
|
||||
continue
|
||||
api_name, health_response = result
|
||||
providers_health[api_name] = health_response
|
||||
|
||||
return providers_health
|
||||
|
|
|
@ -41,7 +41,6 @@ from llama_stack.providers.datatypes import (
|
|||
Api,
|
||||
BenchmarksProtocolPrivate,
|
||||
DatasetsProtocolPrivate,
|
||||
InlineProviderSpec,
|
||||
ModelsProtocolPrivate,
|
||||
ProviderSpec,
|
||||
RemoteProviderConfig,
|
||||
|
@ -230,46 +229,6 @@ def sort_providers_by_deps(
|
|||
{k: list(v.values()) for k, v in providers_with_specs.items()}
|
||||
)
|
||||
|
||||
# Append built-in "inspect" provider
|
||||
apis = [x[1].spec.api for x in sorted_providers]
|
||||
sorted_providers.append(
|
||||
(
|
||||
"inspect",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={"run_config": run_config.model_dump()},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
||||
module="llama_stack.distribution.inspect",
|
||||
api_dependencies=apis,
|
||||
deps__=[x.value for x in apis],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
sorted_providers.append(
|
||||
(
|
||||
"providers",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={"run_config": run_config.model_dump()},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.providers,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.providers.ProviderImplConfig",
|
||||
module="llama_stack.distribution.providers",
|
||||
api_dependencies=apis,
|
||||
deps__=[x.value for x in apis],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Resolved {len(sorted_providers)} providers")
|
||||
for api_str, provider in sorted_providers:
|
||||
logger.debug(f" {api_str} => {provider.provider_id}")
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
|
@ -37,7 +38,13 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.scoring import (
|
||||
|
@ -60,7 +67,7 @@ from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
@ -530,7 +537,7 @@ class InferenceRouter(Inference):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
@ -541,7 +548,7 @@ class InferenceRouter(Inference):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
|
||||
)
|
||||
|
@ -580,6 +587,29 @@ class InferenceRouter(Inference):
|
|||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_chat_completion(**params)
|
||||
|
||||
async def health(self) -> Dict[str, HealthResponse]:
|
||||
health_statuses = {}
|
||||
timeout = 0.5
|
||||
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
|
||||
try:
|
||||
# check if the provider has a health method
|
||||
if not hasattr(impl, "health"):
|
||||
continue
|
||||
health = await asyncio.wait_for(impl.health(), timeout=timeout)
|
||||
health_statuses[provider_id] = health
|
||||
except asyncio.TimeoutError:
|
||||
health_statuses[provider_id] = HealthResponse(
|
||||
status=HealthStatus.ERROR,
|
||||
message=f"Health check timed out after {timeout} seconds",
|
||||
)
|
||||
except NotImplementedError:
|
||||
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
|
||||
except Exception as e:
|
||||
health_statuses[provider_id] = HealthResponse(
|
||||
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
|
||||
)
|
||||
return health_statuses
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
def __init__(
|
||||
|
|
|
@ -38,10 +38,10 @@ from llama_stack.distribution.server.endpoints import (
|
|||
)
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
@ -229,15 +229,30 @@ class TracingMiddleware:
|
|||
def __init__(self, app, impls):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
# FastAPI built-in paths that should bypass custom routing
|
||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Check if the path is a FastAPI built-in path
|
||||
if path.startswith(self.fastapi_paths):
|
||||
# Pass through to FastAPI's built-in handlers
|
||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "endpoint_impls"):
|
||||
self.endpoint_impls = initialize_endpoint_impls(self.impls)
|
||||
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
|
||||
|
||||
try:
|
||||
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
|
||||
|
||||
|
@ -388,7 +403,12 @@ def main(args: Optional[argparse.Namespace] = None):
|
|||
safe_config = redact_sensitive_fields(config.model_dump())
|
||||
logger.info(yaml.dump(safe_config, indent=2))
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
)
|
||||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
|
|
|
@ -35,6 +35,8 @@ from llama_stack.apis.vector_dbs import VectorDBs
|
|||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
@ -119,26 +121,6 @@ class EnvVarError(Exception):
|
|||
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
|
||||
|
||||
|
||||
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Redact sensitive information from config before printing."""
|
||||
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
|
||||
|
||||
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if isinstance(v, dict):
|
||||
result[k] = _redact_dict(v)
|
||||
elif isinstance(v, list):
|
||||
result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
|
||||
elif any(pattern in k.lower() for pattern in sensitive_patterns):
|
||||
result[k] = "********"
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
return _redact_dict(data)
|
||||
|
||||
|
||||
def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||
if isinstance(config, dict):
|
||||
result = {}
|
||||
|
@ -215,6 +197,26 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
|||
) from e
|
||||
|
||||
|
||||
def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None:
|
||||
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
||||
|
||||
Args:
|
||||
impls: Dictionary of API implementations
|
||||
run_config: Stack run configuration
|
||||
"""
|
||||
inspect_impl = DistributionInspectImpl(
|
||||
DistributionInspectConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.inspect] = inspect_impl
|
||||
|
||||
providers_impl = ProviderImpl(
|
||||
ProviderImplConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.providers] = providers_impl
|
||||
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def construct_stack(
|
||||
|
@ -222,6 +224,10 @@ async def construct_stack(
|
|||
) -> Dict[Api, Any]:
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, run_config)
|
||||
|
||||
await register_resources(run_config, impls)
|
||||
return impls
|
||||
|
||||
|
|
|
@ -56,6 +56,17 @@ def tool_chat_page():
|
|||
st.subheader(f"Active Tools: 🛠 {len(active_tool_list)}")
|
||||
st.json(active_tool_list)
|
||||
|
||||
st.subheader("Chat Configurations")
|
||||
max_tokens = st.slider(
|
||||
"Max Tokens",
|
||||
min_value=0,
|
||||
max_value=4096,
|
||||
value=512,
|
||||
step=1,
|
||||
help="The maximum number of tokens to generate",
|
||||
on_change=reset_agent,
|
||||
)
|
||||
|
||||
@st.cache_resource
|
||||
def create_agent():
|
||||
return Agent(
|
||||
|
@ -63,9 +74,7 @@ def tool_chat_page():
|
|||
model=model,
|
||||
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
|
||||
tools=toolgroup_selection,
|
||||
sampling_params={
|
||||
"strategy": {"type": "greedy"},
|
||||
},
|
||||
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
||||
)
|
||||
|
||||
agent = create_agent()
|
||||
|
|
30
llama_stack/distribution/utils/config.py
Normal file
30
llama_stack/distribution/utils/config.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Redact sensitive information from config before printing."""
|
||||
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
|
||||
|
||||
def _redact_value(v: Any) -> Any:
|
||||
if isinstance(v, dict):
|
||||
return _redact_dict(v)
|
||||
elif isinstance(v, list):
|
||||
return [_redact_value(i) for i in v]
|
||||
return v
|
||||
|
||||
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if any(pattern in k.lower() for pattern in sensitive_patterns):
|
||||
result[k] = "********"
|
||||
else:
|
||||
result[k] = _redact_value(v)
|
||||
return result
|
||||
|
||||
return _redact_dict(data)
|
|
@ -204,7 +204,9 @@ class ToolUtils:
|
|||
return None
|
||||
elif is_json(message_body):
|
||||
response = json.loads(message_body)
|
||||
if ("type" in response and response["type"] == "function") or ("name" in response):
|
||||
if ("type" in response and response["type"] == "function") or (
|
||||
"name" in response and "parameters" in response
|
||||
):
|
||||
function_name = response["name"]
|
||||
args = response["parameters"]
|
||||
return function_name, args
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Protocol
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
@ -201,3 +202,12 @@ def remote_provider_spec(
|
|||
adapter=adapter,
|
||||
api_dependencies=api_dependencies or [],
|
||||
)
|
||||
|
||||
|
||||
class HealthStatus(str, Enum):
|
||||
OK = "OK"
|
||||
ERROR = "Error"
|
||||
NOT_IMPLEMENTED = "Not Implemented"
|
||||
|
||||
|
||||
HealthResponse = dict[str, Any]
|
||||
|
|
|
@ -59,8 +59,8 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
augment_content_with_response_format_prompt,
|
||||
|
@ -83,8 +83,8 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_
|
|||
|
||||
|
||||
class MetaReferenceInferenceImpl(
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
ModelsProtocolPrivate,
|
||||
|
|
|
@ -25,8 +25,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
|
|||
SentenceTransformerEmbeddingMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
)
|
||||
|
||||
from .config import SentenceTransformersInferenceConfig
|
||||
|
@ -35,8 +35,8 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class SentenceTransformersInferenceImpl(
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
ModelsProtocolPrivate,
|
||||
|
|
|
@ -66,10 +66,10 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_stop_reason,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
|
@ -176,8 +176,8 @@ def _convert_sampling_params(
|
|||
|
||||
class VLLMInferenceImpl(
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
"""
|
||||
|
|
|
@ -3,13 +3,14 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
ListPostTrainingJobsResponse,
|
||||
|
@ -25,9 +26,19 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
|
|||
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
|
||||
LoraFinetuningSingleDevice,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
class TrainingArtifactType(Enum):
|
||||
CHECKPOINT = "checkpoint"
|
||||
RESOURCES_STATS = "resources_stats"
|
||||
|
||||
|
||||
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
|
||||
|
||||
|
||||
class TorchtunePostTrainingImpl:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -38,13 +49,27 @@ class TorchtunePostTrainingImpl:
|
|||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets
|
||||
self._scheduler = Scheduler()
|
||||
|
||||
# TODO: assume sync job, will need jobs API for async scheduling
|
||||
self.jobs = {}
|
||||
self.checkpoints_dict = {}
|
||||
async def shutdown(self) -> None:
|
||||
await self._scheduler.shutdown()
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
@staticmethod
|
||||
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.CHECKPOINT.value,
|
||||
name=checkpoint.identifier,
|
||||
uri=checkpoint.path,
|
||||
metadata=dict(checkpoint),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
name=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
metadata=resources_stats,
|
||||
)
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
|
@ -56,20 +81,11 @@ class TorchtunePostTrainingImpl:
|
|||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[AlgorithmConfig],
|
||||
) -> PostTrainingJob:
|
||||
if job_uuid in self.jobs:
|
||||
raise ValueError(f"Job {job_uuid} already exists")
|
||||
|
||||
post_training_job = PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
job_status_response = PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=JobStatus.scheduled,
|
||||
scheduled_at=datetime.now(timezone.utc),
|
||||
)
|
||||
self.jobs[job_uuid] = job_status_response
|
||||
|
||||
if isinstance(algorithm_config, LoraFinetuningConfig):
|
||||
try:
|
||||
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
on_log_message_cb("Starting Lora finetuning")
|
||||
|
||||
recipe = LoraFinetuningSingleDevice(
|
||||
self.config,
|
||||
job_uuid,
|
||||
|
@ -82,26 +98,22 @@ class TorchtunePostTrainingImpl:
|
|||
self.datasetio_api,
|
||||
self.datasets_api,
|
||||
)
|
||||
|
||||
job_status_response.status = JobStatus.in_progress
|
||||
job_status_response.started_at = datetime.now(timezone.utc)
|
||||
|
||||
await recipe.setup()
|
||||
|
||||
resources_allocated, checkpoints = await recipe.train()
|
||||
|
||||
self.checkpoints_dict[job_uuid] = checkpoints
|
||||
job_status_response.resources_allocated = resources_allocated
|
||||
job_status_response.checkpoints = checkpoints
|
||||
job_status_response.status = JobStatus.completed
|
||||
job_status_response.completed_at = datetime.now(timezone.utc)
|
||||
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
|
||||
for checkpoint in checkpoints:
|
||||
artifact = self._checkpoint_to_artifact(checkpoint)
|
||||
on_artifact_collected_cb(artifact)
|
||||
|
||||
except Exception:
|
||||
job_status_response.status = JobStatus.failed
|
||||
raise
|
||||
on_status_change_cb(SchedulerJobStatus.completed)
|
||||
on_log_message_cb("Lora finetuning completed")
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return post_training_job
|
||||
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
|
||||
return PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
async def preference_optimize(
|
||||
self,
|
||||
|
@ -114,19 +126,55 @@ class TorchtunePostTrainingImpl:
|
|||
) -> PostTrainingJob: ...
|
||||
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
|
||||
return ListPostTrainingJobsResponse(
|
||||
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_artifacts_metadata_by_type(job, artifact_type):
|
||||
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
|
||||
|
||||
@classmethod
|
||||
def _get_checkpoints(cls, job):
|
||||
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
|
||||
|
||||
@classmethod
|
||||
def _get_resources_allocated(cls, job):
|
||||
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
|
||||
return data[0] if data else None
|
||||
|
||||
@webmethod(route="/post-training/job/status")
|
||||
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
|
||||
return self.jobs.get(job_uuid, None)
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
|
||||
match job.status:
|
||||
# TODO: Add support for other statuses to API
|
||||
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
|
||||
status = JobStatus.scheduled
|
||||
case SchedulerJobStatus.running:
|
||||
status = JobStatus.in_progress
|
||||
case SchedulerJobStatus.completed:
|
||||
status = JobStatus.completed
|
||||
case SchedulerJobStatus.failed:
|
||||
status = JobStatus.failed
|
||||
case _:
|
||||
raise NotImplementedError()
|
||||
|
||||
return PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=status,
|
||||
scheduled_at=job.scheduled_at,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
checkpoints=self._get_checkpoints(job),
|
||||
resources_allocated=self._get_resources_allocated(job),
|
||||
)
|
||||
|
||||
@webmethod(route="/post-training/job/cancel")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
raise NotImplementedError("Job cancel is not implemented yet")
|
||||
self._scheduler.cancel(job_uuid)
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
|
||||
if job_uuid in self.checkpoints_dict:
|
||||
checkpoints = self.checkpoints_dict.get(job_uuid, [])
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
|
||||
return None
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
|
||||
|
|
|
@ -36,10 +36,10 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_strategy_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
@ -56,8 +56,8 @@ from .models import MODEL_ENTRIES
|
|||
class BedrockInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
|
|
|
@ -34,8 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
@ -54,8 +54,8 @@ from .models import MODEL_ENTRIES
|
|||
class CerebrasInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: CerebrasImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
|
|
|
@ -34,8 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
@ -61,8 +61,8 @@ model_entries = [
|
|||
class DatabricksInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=model_entries)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from fireworks.client import Fireworks
|
||||
from openai import AsyncOpenAI
|
||||
|
@ -32,13 +32,20 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
|
@ -301,6 +308,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Fireworks always prepends with BOS
|
||||
if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
|
||||
prompt = prompt[len("<|begin_of_text|>") :]
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
|
@ -320,6 +332,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return await self._get_openai_client().completions.create(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
|
@ -336,7 +349,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
@ -347,10 +360,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
|
@ -374,4 +386,12 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
# Divert Llama Models through Llama Stack inference APIs because
|
||||
# Fireworks chat completions OpenAI-compatible API does not support
|
||||
# tool calls properly.
|
||||
llama_model = self.get_llama_model(model_obj.provider_resource_id)
|
||||
if llama_model:
|
||||
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params)
|
||||
|
||||
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
|
||||
|
|
|
@ -4,8 +4,24 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChoiceDelta,
|
||||
OpenAIChunkChoice,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAISystemMessageParam,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
prepare_openai_completion_params,
|
||||
)
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
@ -21,9 +37,129 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
provider_data_api_key_field="groq_api_key",
|
||||
)
|
||||
self.config = config
|
||||
self._openai_client = None
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
if self._openai_client:
|
||||
await self._openai_client.close()
|
||||
self._openai_client = None
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
if not self._openai_client:
|
||||
self._openai_client = AsyncOpenAI(
|
||||
base_url=f"{self.config.url}/openai/v1",
|
||||
api_key=self.config.api_key,
|
||||
)
|
||||
return self._openai_client
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[OpenAIMessageParam],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Groq does not support json_schema response format, so we need to convert it to json_object
|
||||
if response_format and response_format.type == "json_schema":
|
||||
response_format.type = "json_object"
|
||||
schema = response_format.json_schema.get("schema", {})
|
||||
response_format.json_schema = None
|
||||
json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
|
||||
if messages and messages[0].role == "system":
|
||||
messages[0].content = messages[0].content + json_instructions
|
||||
else:
|
||||
messages.insert(0, OpenAISystemMessageParam(content=json_instructions))
|
||||
|
||||
# Groq returns a 400 error if tools are provided but none are called
|
||||
# So, set tool_choice to "required" to attempt to force a call
|
||||
if tools and (not tool_choice or tool_choice == "auto"):
|
||||
tool_choice = "required"
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id.replace("groq/", ""),
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Groq does not support streaming requests that set response_format
|
||||
fake_stream = False
|
||||
if stream and response_format:
|
||||
params["stream"] = False
|
||||
fake_stream = True
|
||||
|
||||
response = await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
if fake_stream:
|
||||
chunk_choices = []
|
||||
for choice in response.choices:
|
||||
delta = OpenAIChoiceDelta(
|
||||
content=choice.message.content,
|
||||
role=choice.message.role,
|
||||
tool_calls=choice.message.tool_calls,
|
||||
)
|
||||
chunk_choice = OpenAIChunkChoice(
|
||||
delta=delta,
|
||||
finish_reason=choice.finish_reason,
|
||||
index=choice.index,
|
||||
logprobs=None,
|
||||
)
|
||||
chunk_choices.append(chunk_choice)
|
||||
chunk = OpenAIChatCompletionChunk(
|
||||
id=response.id,
|
||||
choices=chunk_choices,
|
||||
object="chat.completion.chunk",
|
||||
created=response.created,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
async def _fake_stream_generator():
|
||||
yield chunk
|
||||
|
||||
return _fake_stream_generator()
|
||||
else:
|
||||
return response
|
||||
|
|
|
@ -39,8 +39,16 @@ MODEL_ENTRIES = [
|
|||
"groq/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
]
|
||||
|
|
|
@ -34,15 +34,18 @@ from llama_stack.apis.inference import (
|
|||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.utils.inference import (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import ToolPromptFormat
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
@ -335,7 +338,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
@ -346,7 +349,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
provider_model_id = self.get_provider_model_id(model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from ollama import AsyncClient
|
||||
|
@ -39,10 +39,20 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
@ -87,8 +97,19 @@ class OllamaInferenceAdapter(
|
|||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
await self.health()
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by verifying connectivity to the Ollama server.
|
||||
This method is used by initialize() and the Provider API to verify that the service is running
|
||||
correctly.
|
||||
Returns:
|
||||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
await self.client.ps()
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except httpx.ConnectError as e:
|
||||
raise RuntimeError(
|
||||
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
|
||||
|
@ -322,6 +343,12 @@ class OllamaInferenceAdapter(
|
|||
response = await self.client.list()
|
||||
available_models = [m["model"] for m in response["models"]]
|
||||
if model.provider_resource_id not in available_models:
|
||||
available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
|
||||
if model.provider_resource_id in available_models_latest:
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||
)
|
||||
return model
|
||||
raise ValueError(
|
||||
f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
|
||||
)
|
||||
|
@ -393,7 +420,7 @@ class OllamaInferenceAdapter(
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
@ -404,7 +431,7 @@ class OllamaInferenceAdapter(
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self._get_model(model)
|
||||
params = {
|
||||
k: v
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from llama_stack_client import AsyncLlamaStackClient
|
||||
|
||||
|
@ -26,7 +26,13 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
@ -266,7 +272,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
@ -277,7 +283,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
client = self._get_client()
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
|
|
|
@ -12,8 +12,8 @@ from llama_stack.apis.inference import * # noqa: F403
|
|||
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
@ -43,8 +43,8 @@ RUNPOD_SUPPORTED_MODELS = {
|
|||
class RunpodInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: RunpodImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
|
||||
|
|
|
@ -42,8 +42,8 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
|
@ -57,8 +57,8 @@ from .models import MODEL_ENTRIES
|
|||
class SambaNovaInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: SambaNovaImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
|
|
@ -40,10 +40,10 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
@ -73,8 +73,8 @@ def build_hf_repo_model_entries():
|
|||
|
||||
class _HfAdapter(
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
client: AsyncInferenceClient
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from together import AsyncTogether
|
||||
|
@ -31,7 +31,13 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
@ -315,7 +321,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
@ -326,7 +332,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
|
@ -353,4 +359,26 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
if params.get("stream", True):
|
||||
return self._stream_openai_chat_completion(params)
|
||||
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||
|
||||
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
|
||||
# together.ai sometimes adds usage data to the stream, even if include_usage is False
|
||||
# This causes an unexpected final chunk with empty choices array to be sent
|
||||
# to clients that may not handle it gracefully.
|
||||
include_usage = False
|
||||
if params.get("stream_options", None):
|
||||
include_usage = params["stream_options"].get("include_usage", False)
|
||||
stream = await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
seen_finish_reason = False
|
||||
async for chunk in stream:
|
||||
# Final usage chunk with no choices that the user didn't request, so discard
|
||||
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
|
||||
break
|
||||
yield chunk
|
||||
for choice in chunk.choices:
|
||||
if choice.finish_reason:
|
||||
seen_finish_reason = True
|
||||
break
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
|
@ -45,7 +45,12 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
|
@ -369,7 +374,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
options["max_tokens"] = self.config.max_tokens
|
||||
|
||||
input_dict: dict[str, Any] = {}
|
||||
if isinstance(request, ChatCompletionRequest) and request.tools is not None:
|
||||
# Only include the 'tools' param if there is any. It can break things if an empty list is sent to the vLLM.
|
||||
if isinstance(request, ChatCompletionRequest) and request.tools:
|
||||
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
|
||||
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
|
@ -487,7 +493,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
@ -498,7 +504,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self._get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
|
|
|
@ -30,7 +30,13 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models.models import Model
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -270,7 +276,7 @@ class LiteLLMOpenAIMixin(
|
|||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self._get_model(model)
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
|
@ -292,7 +298,7 @@ class LiteLLMOpenAIMixin(
|
|||
guided_choice=guided_choice,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
return litellm.text_completion(**params)
|
||||
return await litellm.atext_completion(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
@ -308,7 +314,7 @@ class LiteLLMOpenAIMixin(
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
@ -319,8 +325,8 @@ class LiteLLMOpenAIMixin(
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
model_obj = await self._get_model(model)
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
|
@ -346,7 +352,7 @@ class LiteLLMOpenAIMixin(
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return litellm.completion(**params)
|
||||
return await litellm.acompletion(**params)
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
|
|
|
@ -8,7 +8,7 @@ import logging
|
|||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import (
|
||||
|
@ -50,6 +50,18 @@ from openai.types.chat.chat_completion import (
|
|||
from openai.types.chat.chat_completion import (
|
||||
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
Choice as OpenAIChatCompletionChunkChoice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDelta as OpenAIChoiceDelta,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
|
||||
)
|
||||
from openai.types.chat.chat_completion_content_part_image_param import (
|
||||
ImageURL as OpenAIImageURL,
|
||||
)
|
||||
|
@ -59,6 +71,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
TextContentItem,
|
||||
|
@ -85,12 +98,24 @@ from llama_stack.apis.inference import (
|
|||
TopPSamplingStrategy,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice
|
||||
from llama_stack.apis.inference.inference import (
|
||||
JsonSchemaResponseFormat,
|
||||
OpenAIChatCompletion,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionChoice,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChoice as OpenAIChatCompletionChoice,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_content_to_url,
|
||||
|
@ -751,6 +776,17 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
|||
return out
|
||||
|
||||
|
||||
def _convert_stop_reason_to_openai_finish_reason(stop_reason: StopReason) -> str:
|
||||
"""
|
||||
Convert a StopReason to an OpenAI chat completion finish_reason.
|
||||
"""
|
||||
return {
|
||||
StopReason.end_of_turn: "stop",
|
||||
StopReason.end_of_message: "tool_calls",
|
||||
StopReason.out_of_tokens: "length",
|
||||
}.get(stop_reason, "stop")
|
||||
|
||||
|
||||
def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
|
||||
"""
|
||||
Convert an OpenAI chat completion finish_reason to a StopReason.
|
||||
|
@ -776,6 +812,56 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
|
|||
}.get(finish_reason, StopReason.end_of_turn)
|
||||
|
||||
|
||||
def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
|
||||
tool_config = ToolConfig()
|
||||
if tool_choice:
|
||||
tool_config.tool_choice = tool_choice
|
||||
return tool_config
|
||||
|
||||
|
||||
def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None) -> List[ToolDefinition]:
|
||||
lls_tools = []
|
||||
if not tools:
|
||||
return lls_tools
|
||||
|
||||
for tool in tools:
|
||||
tool_fn = tool.get("function", {})
|
||||
tool_name = tool_fn.get("name", None)
|
||||
tool_desc = tool_fn.get("description", None)
|
||||
|
||||
tool_params = tool_fn.get("parameters", None)
|
||||
lls_tool_params = {}
|
||||
if tool_params is not None:
|
||||
tool_param_properties = tool_params.get("properties", {})
|
||||
for tool_param_key, tool_param_value in tool_param_properties.items():
|
||||
tool_param_def = ToolParamDefinition(
|
||||
param_type=tool_param_value.get("type", None),
|
||||
description=tool_param_value.get("description", None),
|
||||
)
|
||||
lls_tool_params[tool_param_key] = tool_param_def
|
||||
|
||||
lls_tool = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool_desc,
|
||||
parameters=lls_tool_params,
|
||||
)
|
||||
lls_tools.append(lls_tool)
|
||||
return lls_tools
|
||||
|
||||
|
||||
def _convert_openai_request_response_format(response_format: OpenAIResponseFormatParam = None):
|
||||
if not response_format:
|
||||
return None
|
||||
# response_format can be a dict or a pydantic model
|
||||
response_format = dict(response_format)
|
||||
if response_format.get("type", "") == "json_schema":
|
||||
return JsonSchemaResponseFormat(
|
||||
type="json_schema",
|
||||
json_schema=response_format.get("json_schema", {}).get("schema", ""),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _convert_openai_tool_calls(
|
||||
tool_calls: List[OpenAIChatCompletionMessageToolCall],
|
||||
) -> List[ToolCall]:
|
||||
|
@ -871,6 +957,40 @@ def _convert_openai_sampling_params(
|
|||
return sampling_params
|
||||
|
||||
|
||||
def _convert_openai_request_messages(messages: List[OpenAIMessageParam]):
|
||||
# Llama Stack messages and OpenAI messages are similar, but not identical.
|
||||
lls_messages = []
|
||||
for message in messages:
|
||||
lls_message = dict(message)
|
||||
|
||||
# Llama Stack expects `call_id` but OpenAI uses `tool_call_id`
|
||||
tool_call_id = lls_message.pop("tool_call_id", None)
|
||||
if tool_call_id:
|
||||
lls_message["call_id"] = tool_call_id
|
||||
|
||||
content = lls_message.get("content", None)
|
||||
if isinstance(content, list):
|
||||
lls_content = []
|
||||
for item in content:
|
||||
# items can either by pydantic models or dicts here...
|
||||
item = dict(item)
|
||||
if item.get("type", "") == "image_url":
|
||||
lls_item = ImageContentItem(
|
||||
type="image",
|
||||
image=URL(uri=item.get("image_url", {}).get("url", "")),
|
||||
)
|
||||
elif item.get("type", "") == "text":
|
||||
lls_item = TextContentItem(
|
||||
type="text",
|
||||
text=item.get("text", ""),
|
||||
)
|
||||
lls_content.append(lls_item)
|
||||
lls_message["content"] = lls_content
|
||||
lls_messages.append(lls_message)
|
||||
|
||||
return lls_messages
|
||||
|
||||
|
||||
def convert_openai_chat_completion_choice(
|
||||
choice: OpenAIChoice,
|
||||
) -> ChatCompletionResponse:
|
||||
|
@ -1080,11 +1200,24 @@ async def convert_openai_chat_completion_stream(
|
|||
|
||||
|
||||
async def prepare_openai_completion_params(**params):
|
||||
completion_params = {k: v for k, v in params.items() if v is not None}
|
||||
async def _prepare_value(value: Any) -> Any:
|
||||
new_value = value
|
||||
if isinstance(value, list):
|
||||
new_value = [await _prepare_value(v) for v in value]
|
||||
elif isinstance(value, dict):
|
||||
new_value = {k: await _prepare_value(v) for k, v in value.items()}
|
||||
elif isinstance(value, BaseModel):
|
||||
new_value = value.model_dump(exclude_none=True)
|
||||
return new_value
|
||||
|
||||
completion_params = {}
|
||||
for k, v in params.items():
|
||||
if v is not None:
|
||||
completion_params[k] = await _prepare_value(v)
|
||||
return completion_params
|
||||
|
||||
|
||||
class OpenAICompletionUnsupportedMixin:
|
||||
class OpenAICompletionToLlamaStackMixin:
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -1122,6 +1255,7 @@ class OpenAICompletionUnsupportedMixin:
|
|||
|
||||
choices = []
|
||||
# "n" is the number of completions to generate per prompt
|
||||
n = n or 1
|
||||
for _i in range(0, n):
|
||||
# and we may have multiple prompts, if batching was used
|
||||
|
||||
|
@ -1134,7 +1268,7 @@ class OpenAICompletionUnsupportedMixin:
|
|||
|
||||
index = len(choices)
|
||||
text = result.content
|
||||
finish_reason = _convert_openai_finish_reason(result.stop_reason)
|
||||
finish_reason = _convert_stop_reason_to_openai_finish_reason(result.stop_reason)
|
||||
|
||||
choice = OpenAICompletionChoice(
|
||||
index=index,
|
||||
|
@ -1152,7 +1286,7 @@ class OpenAICompletionUnsupportedMixin:
|
|||
)
|
||||
|
||||
|
||||
class OpenAIChatCompletionUnsupportedMixin:
|
||||
class OpenAIChatCompletionToLlamaStackMixin:
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -1167,7 +1301,7 @@ class OpenAIChatCompletionUnsupportedMixin:
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
@ -1178,5 +1312,103 @@ class OpenAIChatCompletionUnsupportedMixin:
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
messages = _convert_openai_request_messages(messages)
|
||||
response_format = _convert_openai_request_response_format(response_format)
|
||||
sampling_params = _convert_openai_sampling_params(
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
)
|
||||
tool_config = _convert_openai_request_tool_config(tool_choice)
|
||||
tools = _convert_openai_request_tools(tools)
|
||||
|
||||
outstanding_responses = []
|
||||
# "n" is the number of completions to generate per prompt
|
||||
n = n or 1
|
||||
for _i in range(0, n):
|
||||
response = self.chat_completion(
|
||||
model_id=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
tool_config=tool_config,
|
||||
tools=tools,
|
||||
)
|
||||
outstanding_responses.append(response)
|
||||
|
||||
if stream:
|
||||
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
|
||||
|
||||
return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
|
||||
self, model, outstanding_responses
|
||||
)
|
||||
|
||||
async def _process_stream_response(
|
||||
self, model: str, outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]]
|
||||
):
|
||||
id = f"chatcmpl-{uuid.uuid4()}"
|
||||
for outstanding_response in outstanding_responses:
|
||||
response = await outstanding_response
|
||||
i = 0
|
||||
async for chunk in response:
|
||||
event = chunk.event
|
||||
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
|
||||
|
||||
if isinstance(event.delta, TextDelta):
|
||||
text_delta = event.delta.text
|
||||
delta = OpenAIChoiceDelta(content=text_delta)
|
||||
yield OpenAIChatCompletionChunk(
|
||||
id=id,
|
||||
choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
)
|
||||
elif isinstance(event.delta, ToolCallDelta):
|
||||
if event.delta.parse_status == ToolCallParseStatus.succeeded:
|
||||
tool_call = event.delta.tool_call
|
||||
openai_tool_call = OpenAIChoiceDeltaToolCall(
|
||||
index=0,
|
||||
id=tool_call.call_id,
|
||||
function=OpenAIChoiceDeltaToolCallFunction(
|
||||
name=tool_call.tool_name, arguments=tool_call.arguments_json
|
||||
),
|
||||
)
|
||||
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
|
||||
yield OpenAIChatCompletionChunk(
|
||||
id=id,
|
||||
choices=[
|
||||
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
|
||||
],
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
)
|
||||
i = i + 1
|
||||
|
||||
async def _process_non_stream_response(
|
||||
self, model: str, outstanding_responses: List[Awaitable[ChatCompletionResponse]]
|
||||
) -> OpenAIChatCompletion:
|
||||
raise ValueError(f"{self.__class__.__name__} doesn't support openai chat completion")
|
||||
choices = []
|
||||
for outstanding_response in outstanding_responses:
|
||||
response = await outstanding_response
|
||||
completion_message = response.completion_message
|
||||
message = await convert_message_to_openai_dict_new(completion_message)
|
||||
finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)
|
||||
|
||||
choice = OpenAIChatCompletionChoice(
|
||||
index=len(choices),
|
||||
message=message,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
return OpenAIChatCompletion(
|
||||
id=f"chatcmpl-{uuid.uuid4()}",
|
||||
choices=choices,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
)
|
||||
|
|
265
llama_stack/providers/utils/scheduler.py
Normal file
265
llama_stack/providers/utils/scheduler.py
Normal file
|
@ -0,0 +1,265 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import functools
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Coroutine, Dict, Iterable, Tuple, TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="scheduler")
|
||||
|
||||
|
||||
# TODO: revisit the list of possible statuses when defining a more coherent
|
||||
# Jobs API for all API flows; e.g. do we need new vs scheduled?
|
||||
class JobStatus(Enum):
|
||||
new = "new"
|
||||
scheduled = "scheduled"
|
||||
running = "running"
|
||||
failed = "failed"
|
||||
completed = "completed"
|
||||
|
||||
|
||||
JobID: TypeAlias = str
|
||||
JobType: TypeAlias = str
|
||||
|
||||
|
||||
class JobArtifact(BaseModel):
|
||||
type: JobType
|
||||
name: str
|
||||
# TODO: uri should be a reference to /files API; revisit when /files is implemented
|
||||
uri: str | None = None
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
JobHandler = Callable[
|
||||
[Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None]
|
||||
]
|
||||
|
||||
|
||||
LogMessage: TypeAlias = Tuple[datetime, str]
|
||||
|
||||
|
||||
_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}
|
||||
|
||||
|
||||
class Job:
|
||||
def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
|
||||
super().__init__()
|
||||
self.id = job_id
|
||||
self._type = job_type
|
||||
self._handler = handler
|
||||
self._artifacts: list[JobArtifact] = []
|
||||
self._logs: list[LogMessage] = []
|
||||
self._state_transitions: list[Tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
|
||||
|
||||
@property
|
||||
def handler(self) -> JobHandler:
|
||||
return self._handler
|
||||
|
||||
@property
|
||||
def status(self) -> JobStatus:
|
||||
return self._state_transitions[-1][1]
|
||||
|
||||
@status.setter
|
||||
def status(self, status: JobStatus):
|
||||
if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
|
||||
raise ValueError(f"Job is already in a completed state ({self.status})")
|
||||
if self.status == status:
|
||||
return
|
||||
self._state_transitions.append((datetime.now(timezone.utc), status))
|
||||
|
||||
@property
|
||||
def artifacts(self) -> list[JobArtifact]:
|
||||
return self._artifacts
|
||||
|
||||
def register_artifact(self, artifact: JobArtifact) -> None:
|
||||
self._artifacts.append(artifact)
|
||||
|
||||
def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
|
||||
for date, s in reversed(self._state_transitions):
|
||||
if s in status:
|
||||
return date
|
||||
return None
|
||||
|
||||
@property
|
||||
def scheduled_at(self) -> datetime | None:
|
||||
return self._find_state_transition_date([JobStatus.scheduled])
|
||||
|
||||
@property
|
||||
def started_at(self) -> datetime | None:
|
||||
return self._find_state_transition_date([JobStatus.running])
|
||||
|
||||
@property
|
||||
def completed_at(self) -> datetime | None:
|
||||
return self._find_state_transition_date(_COMPLETED_STATUSES)
|
||||
|
||||
@property
|
||||
def logs(self) -> list[LogMessage]:
|
||||
return self._logs[:]
|
||||
|
||||
def append_log(self, message: LogMessage) -> None:
|
||||
self._logs.append(message)
|
||||
|
||||
# TODO: implement
|
||||
def cancel(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _SchedulerBackend(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def shutdown(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def schedule(
|
||||
self,
|
||||
job: Job,
|
||||
on_log_message_cb: Callable[[str], None],
|
||||
on_status_change_cb: Callable[[JobStatus], None],
|
||||
on_artifact_collected_cb: Callable[[JobArtifact], None],
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _NaiveSchedulerBackend(_SchedulerBackend):
|
||||
def __init__(self, timeout: int = 5):
|
||||
self._timeout = timeout
|
||||
self._loop = asyncio.new_event_loop()
|
||||
# There may be performance implications of using threads due to Python
|
||||
# GIL; may need to measure if it's a real problem though
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._loop.run_forever()
|
||||
|
||||
# When stopping the loop, give tasks a chance to finish
|
||||
# TODO: should we explicitly inform jobs of pending stoppage?
|
||||
for task in asyncio.all_tasks(self._loop):
|
||||
self._loop.run_until_complete(task)
|
||||
self._loop.close()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
self._thread.join()
|
||||
|
||||
# TODO: decouple scheduling and running the job
|
||||
def schedule(
|
||||
self,
|
||||
job: Job,
|
||||
on_log_message_cb: Callable[[str], None],
|
||||
on_status_change_cb: Callable[[JobStatus], None],
|
||||
on_artifact_collected_cb: Callable[[JobArtifact], None],
|
||||
) -> None:
|
||||
async def do():
|
||||
try:
|
||||
job.status = JobStatus.running
|
||||
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
|
||||
except Exception as e:
|
||||
on_log_message_cb(str(e))
|
||||
job.status = JobStatus.failed
|
||||
logger.exception(f"Job {job.id} failed.")
|
||||
|
||||
asyncio.run_coroutine_threadsafe(do(), self._loop)
|
||||
|
||||
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
|
||||
pass
|
||||
|
||||
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
|
||||
pass
|
||||
|
||||
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
|
||||
pass
|
||||
|
||||
|
||||
_BACKENDS = {
|
||||
"naive": _NaiveSchedulerBackend,
|
||||
}
|
||||
|
||||
|
||||
def _get_backend_impl(backend: str) -> _SchedulerBackend:
|
||||
try:
|
||||
return _BACKENDS[backend]()
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Unknown backend {backend}") from e
|
||||
|
||||
|
||||
class Scheduler:
|
||||
def __init__(self, backend: str = "naive"):
|
||||
# TODO: if server crashes, job states are lost; we need to persist jobs on disc
|
||||
self._jobs: dict[JobID, Job] = {}
|
||||
self._backend = _get_backend_impl(backend)
|
||||
|
||||
def _on_log_message_cb(self, job: Job, message: str) -> None:
|
||||
msg = (datetime.now(timezone.utc), message)
|
||||
# At least for the time being, until there's a better way to expose
|
||||
# logs to users, log messages on console
|
||||
logger.info(f"Job {job.id}: {message}")
|
||||
job.append_log(msg)
|
||||
self._backend.on_log_message_cb(job, msg)
|
||||
|
||||
def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
|
||||
job.status = status
|
||||
self._backend.on_status_change_cb(job, status)
|
||||
|
||||
def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
|
||||
job.register_artifact(artifact)
|
||||
self._backend.on_artifact_collected_cb(job, artifact)
|
||||
|
||||
def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
|
||||
job = Job(type_, job_id, handler)
|
||||
if job.id in self._jobs:
|
||||
raise ValueError(f"Job {job.id} already exists")
|
||||
|
||||
self._jobs[job.id] = job
|
||||
job.status = JobStatus.scheduled
|
||||
self._backend.schedule(
|
||||
job,
|
||||
functools.partial(self._on_log_message_cb, job),
|
||||
functools.partial(self._on_status_change_cb, job),
|
||||
functools.partial(self._on_artifact_collected_cb, job),
|
||||
)
|
||||
|
||||
return job.id
|
||||
|
||||
def cancel(self, job_id: JobID) -> None:
|
||||
self.get_job(job_id).cancel()
|
||||
|
||||
def get_job(self, job_id: JobID) -> Job:
|
||||
try:
|
||||
return self._jobs[job_id]
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Job {job_id} not found") from e
|
||||
|
||||
def get_jobs(self, type_: JobType | None = None) -> list[Job]:
|
||||
jobs = list(self._jobs.values())
|
||||
if type_:
|
||||
jobs = [job for job in jobs if job._type == type_]
|
||||
return jobs
|
||||
|
||||
async def shutdown(self):
|
||||
# TODO: also cancel jobs once implemented
|
||||
await self._backend.shutdown()
|
|
@ -386,6 +386,16 @@ models:
|
|||
provider_id: groq
|
||||
provider_model_id: groq/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
provider_id: groq
|
||||
|
@ -396,6 +406,16 @@ models:
|
|||
provider_id: groq
|
||||
provider_model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
|
|
|
@ -158,6 +158,16 @@ models:
|
|||
provider_id: groq
|
||||
provider_model_id: groq/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
provider_id: groq
|
||||
|
@ -168,6 +178,16 @@ models:
|
|||
provider_id: groq
|
||||
provider_model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
|
|
|
@ -28,7 +28,7 @@ The following environment variables can be configured:
|
|||
|
||||
## Setting up vLLM server
|
||||
|
||||
In the following sections, we'll use either AMD and NVIDIA GPUs to serve as hardware accelerators for the vLLM
|
||||
In the following sections, we'll use AMD, NVIDIA or Intel GPUs to serve as hardware accelerators for the vLLM
|
||||
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
|
||||
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
|
||||
that we only use GPUs here for demonstration purposes.
|
||||
|
@ -149,6 +149,55 @@ docker run \
|
|||
--port $SAFETY_PORT
|
||||
```
|
||||
|
||||
### Setting up vLLM server on Intel GPU
|
||||
|
||||
Refer to [vLLM Documentation for XPU](https://docs.vllm.ai/en/v0.8.2/getting_started/installation/gpu.html?device=xpu) to get a vLLM endpoint. In addition to vLLM side setup which guides towards installing vLLM from sources orself-building vLLM Docker container, Intel provides prebuilt vLLM container to use on systems with Intel GPUs supported by PyTorch XPU backend:
|
||||
- [intel/vllm](https://hub.docker.com/r/intel/vllm)
|
||||
|
||||
Here is a sample script to start a vLLM server locally via Docker using Intel provided container:
|
||||
|
||||
```bash
|
||||
export INFERENCE_PORT=8000
|
||||
export INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct
|
||||
export ZE_AFFINITY_MASK=0
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
--device /dev/dri \
|
||||
-v /dev/dri/by-path:/dev/dri/by-path \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||
--env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
|
||||
-p $INFERENCE_PORT:$INFERENCE_PORT \
|
||||
--ipc=host \
|
||||
intel/vllm:xpu \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--model $INFERENCE_MODEL \
|
||||
--port $INFERENCE_PORT
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
|
||||
|
||||
```bash
|
||||
export SAFETY_PORT=8081
|
||||
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
export ZE_AFFINITY_MASK=1
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
--device /dev/dri \
|
||||
-v /dev/dri/by-path:/dev/dri/by-path \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||
--env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
|
||||
-p $SAFETY_PORT:$SAFETY_PORT \
|
||||
--ipc=host \
|
||||
intel/vllm:xpu \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--model $SAFETY_MODEL \
|
||||
--port $SAFETY_PORT
|
||||
```
|
||||
|
||||
## Running Llama Stack
|
||||
|
||||
Now you are ready to run Llama Stack with vLLM as the inference provider. You can do this via Conda (build code) or Docker which has a pre-built image.
|
||||
|
|
|
@ -474,6 +474,16 @@ models:
|
|||
provider_id: groq-openai-compat
|
||||
provider_model_id: groq/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
provider_id: groq-openai-compat
|
||||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
provider_id: groq-openai-compat
|
||||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
provider_id: groq-openai-compat
|
||||
|
@ -484,6 +494,16 @@ models:
|
|||
provider_id: groq-openai-compat
|
||||
provider_model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
provider_id: groq-openai-compat
|
||||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
|
||||
provider_id: groq-openai-compat
|
||||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: Meta-Llama-3.1-8B-Instruct
|
||||
provider_id: sambanova-openai-compat
|
||||
|
|
|
@ -115,7 +115,7 @@ def test_openai_completion_streaming(openai_client, client_with_models, text_mod
|
|||
stream=True,
|
||||
max_tokens=50,
|
||||
)
|
||||
streamed_content = [chunk.choices[0].text for chunk in response]
|
||||
streamed_content = [chunk.choices[0].text or "" for chunk in response]
|
||||
content_str = "".join(streamed_content).lower().strip()
|
||||
assert len(content_str) > 10
|
||||
|
||||
|
|
|
@ -26,7 +26,12 @@ from openai.types.chat.chat_completion_chunk import (
|
|||
)
|
||||
from openai.types.model import Model as OpenAIModel
|
||||
|
||||
from llama_stack.apis.inference import ToolChoice, ToolConfig
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.models.llama.datatypes import StopReason
|
||||
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
|
||||
|
@ -232,3 +237,14 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
|
|||
# above.
|
||||
asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
|
||||
assert not asyncio_warnings
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_params_empty_tools(vllm_inference_adapter):
|
||||
request = ChatCompletionRequest(
|
||||
tools=[],
|
||||
model="test_model",
|
||||
messages=[UserMessage(content="test")],
|
||||
)
|
||||
params = await vllm_inference_adapter._get_params(request)
|
||||
assert "tools" not in params
|
||||
|
|
120
tests/unit/providers/utils/test_scheduler.py
Normal file
120
tests/unit/providers/utils/test_scheduler.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scheduler_unknown_backend():
|
||||
with pytest.raises(ValueError):
|
||||
Scheduler(backend="unknown")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scheduler_naive():
|
||||
sched = Scheduler()
|
||||
|
||||
# make sure the scheduler starts empty
|
||||
with pytest.raises(ValueError):
|
||||
sched.get_job("unknown")
|
||||
assert sched.get_jobs() == []
|
||||
|
||||
called = False
|
||||
|
||||
# schedule a job that will exercise the handlers
|
||||
async def job_handler(on_log, on_status, on_artifact):
|
||||
nonlocal called
|
||||
called = True
|
||||
# exercise the handlers
|
||||
on_log("test log1")
|
||||
on_log("test log2")
|
||||
on_artifact({"type": "type1", "path": "path1"})
|
||||
on_artifact({"type": "type2", "path": "path2"})
|
||||
on_status(JobStatus.completed)
|
||||
|
||||
job_id = "test_job_id"
|
||||
job_type = "test_job_type"
|
||||
sched.schedule(job_type, job_id, job_handler)
|
||||
|
||||
# make sure the job was properly registered
|
||||
with pytest.raises(ValueError):
|
||||
sched.get_job("unknown")
|
||||
assert sched.get_job(job_id) is not None
|
||||
assert sched.get_jobs() == [sched.get_job(job_id)]
|
||||
|
||||
assert sched.get_jobs("unknown") == []
|
||||
assert sched.get_jobs(job_type) == [sched.get_job(job_id)]
|
||||
|
||||
# now shut the scheduler down and make sure the job ran
|
||||
await sched.shutdown()
|
||||
|
||||
assert called
|
||||
|
||||
job = sched.get_job(job_id)
|
||||
assert job is not None
|
||||
|
||||
assert job.status == JobStatus.completed
|
||||
|
||||
assert job.scheduled_at is not None
|
||||
assert job.started_at is not None
|
||||
assert job.completed_at is not None
|
||||
assert job.scheduled_at < job.started_at < job.completed_at
|
||||
|
||||
assert job.artifacts == [
|
||||
{"type": "type1", "path": "path1"},
|
||||
{"type": "type2", "path": "path2"},
|
||||
]
|
||||
assert [msg[1] for msg in job.logs] == ["test log1", "test log2"]
|
||||
assert job.logs[0][0] < job.logs[1][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scheduler_naive_handler_raises():
|
||||
sched = Scheduler()
|
||||
|
||||
async def failing_job_handler(on_log, on_status, on_artifact):
|
||||
on_status(JobStatus.running)
|
||||
raise ValueError("test error")
|
||||
|
||||
job_id = "test_job_id1"
|
||||
job_type = "test_job_type"
|
||||
sched.schedule(job_type, job_id, failing_job_handler)
|
||||
|
||||
job = sched.get_job(job_id)
|
||||
assert job is not None
|
||||
|
||||
# confirm the exception made the job transition to failed state, even
|
||||
# though it was set to `running` before the error
|
||||
for _ in range(10):
|
||||
if job.status == JobStatus.failed:
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
assert job.status == JobStatus.failed
|
||||
|
||||
# confirm that the raised error got registered in log
|
||||
assert job.logs[0][1] == "test error"
|
||||
|
||||
# even after failed job, we can schedule another one
|
||||
called = False
|
||||
|
||||
async def successful_job_handler(on_log, on_status, on_artifact):
|
||||
nonlocal called
|
||||
called = True
|
||||
on_status(JobStatus.completed)
|
||||
|
||||
job_id = "test_job_id2"
|
||||
sched.schedule(job_type, job_id, successful_job_handler)
|
||||
|
||||
await sched.shutdown()
|
||||
|
||||
assert called
|
||||
job = sched.get_job(job_id)
|
||||
assert job is not None
|
||||
assert job.status == JobStatus.completed
|
|
@ -1,6 +1,6 @@
|
|||
# Test Results Report
|
||||
|
||||
*Generated on: 2025-04-10 16:48:18*
|
||||
*Generated on: 2025-04-14 18:11:37*
|
||||
|
||||
*This report was generated by running `python tests/verifications/generate_report.py`*
|
||||
|
||||
|
@ -15,15 +15,15 @@
|
|||
|
||||
| Provider | Pass Rate | Tests Passed | Total Tests |
|
||||
| --- | --- | --- | --- |
|
||||
| Together | 64.7% | 22 | 34 |
|
||||
| Fireworks | 82.4% | 28 | 34 |
|
||||
| Openai | 100.0% | 24 | 24 |
|
||||
| Together | 48.7% | 37 | 76 |
|
||||
| Fireworks | 47.4% | 36 | 76 |
|
||||
| Openai | 100.0% | 52 | 52 |
|
||||
|
||||
|
||||
|
||||
## Together
|
||||
|
||||
*Tests run on: 2025-04-10 16:46:35*
|
||||
*Tests run on: 2025-04-14 18:08:14*
|
||||
|
||||
```bash
|
||||
# Run all tests for this provider:
|
||||
|
@ -48,19 +48,33 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=togethe
|
|||
| test_chat_non_streaming_basic (earth) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_image | ⚪ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ | ❌ | ✅ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_structured_output (calendar) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_structured_output (math) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_tool_calling | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_tool_choice_none | ❌ | ❌ | ❌ |
|
||||
| test_chat_non_streaming_tool_choice_required | ✅ | ✅ | ✅ |
|
||||
| test_chat_streaming_basic (earth) | ✅ | ❌ | ❌ |
|
||||
| test_chat_streaming_basic (saturn) | ✅ | ❌ | ❌ |
|
||||
| test_chat_streaming_image | ⚪ | ❌ | ❌ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ❌ | ❌ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ❌ | ❌ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ❌ | ❌ | ❌ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ❌ | ❌ | ❌ |
|
||||
| test_chat_streaming_structured_output (calendar) | ✅ | ❌ | ❌ |
|
||||
| test_chat_streaming_structured_output (math) | ✅ | ❌ | ❌ |
|
||||
| test_chat_streaming_tool_calling | ✅ | ❌ | ❌ |
|
||||
| test_chat_streaming_tool_choice_none | ❌ | ❌ | ❌ |
|
||||
| test_chat_streaming_tool_choice_required | ✅ | ❌ | ❌ |
|
||||
|
||||
## Fireworks
|
||||
|
||||
*Tests run on: 2025-04-10 16:44:44*
|
||||
*Tests run on: 2025-04-14 18:04:06*
|
||||
|
||||
```bash
|
||||
# Run all tests for this provider:
|
||||
|
@ -85,19 +99,33 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=firewor
|
|||
| test_chat_non_streaming_basic (earth) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_image | ⚪ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ❌ | ❌ | ❌ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ❌ | ❌ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ❌ | ❌ | ❌ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ❌ | ❌ | ❌ |
|
||||
| test_chat_non_streaming_structured_output (calendar) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_structured_output (math) | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_tool_calling | ❌ | ❌ | ❌ |
|
||||
| test_chat_non_streaming_tool_choice_none | ✅ | ✅ | ✅ |
|
||||
| test_chat_non_streaming_tool_choice_required | ✅ | ❌ | ❌ |
|
||||
| test_chat_streaming_basic (earth) | ✅ | ✅ | ✅ |
|
||||
| test_chat_streaming_basic (saturn) | ✅ | ✅ | ✅ |
|
||||
| test_chat_streaming_image | ⚪ | ✅ | ✅ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ❌ | ❌ | ❌ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ❌ | ❌ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ❌ | ❌ | ❌ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ❌ | ❌ | ❌ |
|
||||
| test_chat_streaming_structured_output (calendar) | ✅ | ✅ | ✅ |
|
||||
| test_chat_streaming_structured_output (math) | ✅ | ✅ | ✅ |
|
||||
| test_chat_streaming_tool_calling | ❌ | ❌ | ❌ |
|
||||
| test_chat_streaming_tool_choice_none | ✅ | ✅ | ✅ |
|
||||
| test_chat_streaming_tool_choice_required | ✅ | ❌ | ❌ |
|
||||
|
||||
## Openai
|
||||
|
||||
*Tests run on: 2025-04-10 16:47:28*
|
||||
*Tests run on: 2025-04-14 18:09:51*
|
||||
|
||||
```bash
|
||||
# Run all tests for this provider:
|
||||
|
@ -121,12 +149,26 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai
|
|||
| test_chat_non_streaming_basic (earth) | ✅ | ✅ |
|
||||
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ |
|
||||
| test_chat_non_streaming_image | ✅ | ✅ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ✅ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ | ✅ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ | ✅ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ | ✅ |
|
||||
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ | ✅ |
|
||||
| test_chat_non_streaming_structured_output (calendar) | ✅ | ✅ |
|
||||
| test_chat_non_streaming_structured_output (math) | ✅ | ✅ |
|
||||
| test_chat_non_streaming_tool_calling | ✅ | ✅ |
|
||||
| test_chat_non_streaming_tool_choice_none | ✅ | ✅ |
|
||||
| test_chat_non_streaming_tool_choice_required | ✅ | ✅ |
|
||||
| test_chat_streaming_basic (earth) | ✅ | ✅ |
|
||||
| test_chat_streaming_basic (saturn) | ✅ | ✅ |
|
||||
| test_chat_streaming_image | ✅ | ✅ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ✅ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ | ✅ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ | ✅ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ | ✅ |
|
||||
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ | ✅ |
|
||||
| test_chat_streaming_structured_output (calendar) | ✅ | ✅ |
|
||||
| test_chat_streaming_structured_output (math) | ✅ | ✅ |
|
||||
| test_chat_streaming_tool_calling | ✅ | ✅ |
|
||||
| test_chat_streaming_tool_choice_none | ✅ | ✅ |
|
||||
| test_chat_streaming_tool_choice_required | ✅ | ✅ |
|
||||
|
|
14
tests/verifications/conf/fireworks-llama-stack.yaml
Normal file
14
tests/verifications/conf/fireworks-llama-stack.yaml
Normal file
|
@ -0,0 +1,14 @@
|
|||
base_url: http://localhost:8321/v1/openai/v1
|
||||
api_key_var: FIREWORKS_API_KEY
|
||||
models:
|
||||
- fireworks/llama-v3p3-70b-instruct
|
||||
- fireworks/llama4-scout-instruct-basic
|
||||
- fireworks/llama4-maverick-instruct-basic
|
||||
model_display_names:
|
||||
fireworks/llama-v3p3-70b-instruct: Llama-3.3-70B-Instruct
|
||||
fireworks/llama4-scout-instruct-basic: Llama-4-Scout-Instruct
|
||||
fireworks/llama4-maverick-instruct-basic: Llama-4-Maverick-Instruct
|
||||
test_exclusions:
|
||||
fireworks/llama-v3p3-70b-instruct:
|
||||
- test_chat_non_streaming_image
|
||||
- test_chat_streaming_image
|
14
tests/verifications/conf/groq-llama-stack.yaml
Normal file
14
tests/verifications/conf/groq-llama-stack.yaml
Normal file
|
@ -0,0 +1,14 @@
|
|||
base_url: http://localhost:8321/v1/openai/v1
|
||||
api_key_var: GROQ_API_KEY
|
||||
models:
|
||||
- groq/llama-3.3-70b-versatile
|
||||
- groq/llama-4-scout-17b-16e-instruct
|
||||
- groq/llama-4-maverick-17b-128e-instruct
|
||||
model_display_names:
|
||||
groq/llama-3.3-70b-versatile: Llama-3.3-70B-Instruct
|
||||
groq/llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
|
||||
groq/llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
|
||||
test_exclusions:
|
||||
groq/llama-3.3-70b-versatile:
|
||||
- test_chat_non_streaming_image
|
||||
- test_chat_streaming_image
|
|
@ -2,12 +2,12 @@ base_url: https://api.groq.com/openai/v1
|
|||
api_key_var: GROQ_API_KEY
|
||||
models:
|
||||
- llama-3.3-70b-versatile
|
||||
- llama-4-scout-17b-16e-instruct
|
||||
- llama-4-maverick-17b-128e-instruct
|
||||
- meta-llama/llama-4-scout-17b-16e-instruct
|
||||
- meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_display_names:
|
||||
llama-3.3-70b-versatile: Llama-3.3-70B-Instruct
|
||||
llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
|
||||
llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
|
||||
meta-llama/llama-4-scout-17b-16e-instruct: Llama-4-Scout-Instruct
|
||||
meta-llama/llama-4-maverick-17b-128e-instruct: Llama-4-Maverick-Instruct
|
||||
test_exclusions:
|
||||
llama-3.3-70b-versatile:
|
||||
- test_chat_non_streaming_image
|
||||
|
|
9
tests/verifications/conf/openai-llama-stack.yaml
Normal file
9
tests/verifications/conf/openai-llama-stack.yaml
Normal file
|
@ -0,0 +1,9 @@
|
|||
base_url: http://localhost:8321/v1/openai/v1
|
||||
api_key_var: OPENAI_API_KEY
|
||||
models:
|
||||
- openai/gpt-4o
|
||||
- openai/gpt-4o-mini
|
||||
model_display_names:
|
||||
openai/gpt-4o: gpt-4o
|
||||
openai/gpt-4o-mini: gpt-4o-mini
|
||||
test_exclusions: {}
|
14
tests/verifications/conf/together-llama-stack.yaml
Normal file
14
tests/verifications/conf/together-llama-stack.yaml
Normal file
|
@ -0,0 +1,14 @@
|
|||
base_url: http://localhost:8321/v1/openai/v1
|
||||
api_key_var: TOGETHER_API_KEY
|
||||
models:
|
||||
- together/meta-llama/Llama-3.3-70B-Instruct-Turbo
|
||||
- together/meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
- together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
|
||||
model_display_names:
|
||||
together/meta-llama/Llama-3.3-70B-Instruct-Turbo: Llama-3.3-70B-Instruct
|
||||
together/meta-llama/Llama-4-Scout-17B-16E-Instruct: Llama-4-Scout-Instruct
|
||||
together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8: Llama-4-Maverick-Instruct
|
||||
test_exclusions:
|
||||
together/meta-llama/Llama-3.3-70B-Instruct-Turbo:
|
||||
- test_chat_non_streaming_image
|
||||
- test_chat_streaming_image
|
|
@ -67,7 +67,17 @@ RESULTS_DIR.mkdir(exist_ok=True)
|
|||
# Maximum number of test result files to keep per provider
|
||||
MAX_RESULTS_PER_PROVIDER = 1
|
||||
|
||||
PROVIDER_ORDER = ["together", "fireworks", "groq", "cerebras", "openai"]
|
||||
PROVIDER_ORDER = [
|
||||
"together",
|
||||
"fireworks",
|
||||
"groq",
|
||||
"cerebras",
|
||||
"openai",
|
||||
"together-llama-stack",
|
||||
"fireworks-llama-stack",
|
||||
"groq-llama-stack",
|
||||
"openai-llama-stack",
|
||||
]
|
||||
|
||||
VERIFICATION_CONFIG = _load_all_verification_configs()
|
||||
|
||||
|
|
146
tests/verifications/openai-api-verification-run.yaml
Normal file
146
tests/verifications/openai-api-verification-run.yaml
Normal file
|
@ -0,0 +1,146 @@
|
|||
version: '2'
|
||||
image_name: openai-api-verification
|
||||
apis:
|
||||
- inference
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY:}
|
||||
- provider_id: fireworks
|
||||
provider_type: remote::fireworks
|
||||
config:
|
||||
url: https://api.fireworks.ai/inference/v1
|
||||
api_key: ${env.FIREWORKS_API_KEY}
|
||||
- provider_id: groq
|
||||
provider_type: remote::groq
|
||||
config:
|
||||
url: https://api.groq.com
|
||||
api_key: ${env.GROQ_API_KEY}
|
||||
- provider_id: openai
|
||||
provider_type: remote::openai
|
||||
config:
|
||||
url: https://api.openai.com/v1
|
||||
api_key: ${env.OPENAI_API_KEY:}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/faiss_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/openai/trace_store.db}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:}
|
||||
max_results: 3
|
||||
- provider_id: code-interpreter
|
||||
provider_type: inline::code-interpreter
|
||||
config: {}
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
- provider_id: wolfram-alpha
|
||||
provider_type: remote::wolfram-alpha
|
||||
config:
|
||||
api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/registry.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: together/meta-llama/Llama-3.3-70B-Instruct-Turbo
|
||||
provider_id: together
|
||||
provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: together/meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
provider_id: together
|
||||
provider_model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
|
||||
provider_id: together
|
||||
provider_model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: fireworks/llama-v3p3-70b-instruct
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: fireworks/llama4-scout-instruct-basic
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: fireworks/llama4-maverick-instruct-basic
|
||||
provider_id: fireworks
|
||||
provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/llama-3.3-70b-versatile
|
||||
provider_id: groq
|
||||
provider_model_id: groq/llama-3.3-70b-versatile
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/llama-4-scout-17b-16e-instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: openai/gpt-4o
|
||||
provider_id: openai
|
||||
provider_model_id: openai/gpt-4o
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: openai/gpt-4o-mini
|
||||
provider_id: openai
|
||||
provider_model_id: openai/gpt-4o-mini
|
||||
model_type: llm
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
- toolgroup_id: builtin::code_interpreter
|
||||
provider_id: code-interpreter
|
||||
- toolgroup_id: builtin::wolfram_alpha
|
||||
provider_id: wolfram-alpha
|
||||
server:
|
||||
port: 8321
|
|
@ -99,6 +99,9 @@ def model_mapping(provider, providers_model_mapping):
|
|||
|
||||
@pytest.fixture
|
||||
def openai_client(base_url, api_key):
|
||||
# Simplify running against a local Llama Stack
|
||||
if "localhost" in base_url and not api_key:
|
||||
api_key = "empty"
|
||||
return OpenAI(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
|
|
|
@ -131,3 +131,221 @@ test_tool_calling:
|
|||
type: object
|
||||
type: function
|
||||
output: get_weather_tool_call
|
||||
|
||||
test_chat_multi_turn_tool_calling:
|
||||
test_name: test_chat_multi_turn_tool_calling
|
||||
test_params:
|
||||
case:
|
||||
- case_id: "text_then_weather_tool"
|
||||
input:
|
||||
messages:
|
||||
- - role: user
|
||||
content: "What's the name of the Sun in latin?"
|
||||
- - role: user
|
||||
content: "What's the weather like in San Francisco?"
|
||||
tools:
|
||||
- function:
|
||||
description: Get the current weather
|
||||
name: get_weather
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
location:
|
||||
description: "The city and state (both required), e.g. San Francisco, CA."
|
||||
type: string
|
||||
required: ["location"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': '70 degrees and foggy'}"
|
||||
expected:
|
||||
- num_tool_calls: 0
|
||||
answer: ["sol"]
|
||||
- num_tool_calls: 1
|
||||
tool_name: get_weather
|
||||
tool_arguments:
|
||||
location: "San Francisco, CA"
|
||||
- num_tool_calls: 0
|
||||
answer: ["foggy", "70 degrees"]
|
||||
- case_id: "weather_tool_then_text"
|
||||
input:
|
||||
messages:
|
||||
- - role: user
|
||||
content: "What's the weather like in San Francisco?"
|
||||
tools:
|
||||
- function:
|
||||
description: Get the current weather
|
||||
name: get_weather
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
location:
|
||||
description: "The city and state (both required), e.g. San Francisco, CA."
|
||||
type: string
|
||||
required: ["location"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': '70 degrees and foggy'}"
|
||||
expected:
|
||||
- num_tool_calls: 1
|
||||
tool_name: get_weather
|
||||
tool_arguments:
|
||||
location: "San Francisco, CA"
|
||||
- num_tool_calls: 0
|
||||
answer: ["foggy", "70 degrees"]
|
||||
- case_id: "add_product_tool"
|
||||
input:
|
||||
messages:
|
||||
- - role: user
|
||||
content: "Please add a new product with name 'Widget', price 19.99, in stock, and tags ['new', 'sale'] and give me the product id."
|
||||
tools:
|
||||
- function:
|
||||
description: Add a new product
|
||||
name: addProduct
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: "Name of the product"
|
||||
type: string
|
||||
price:
|
||||
description: "Price of the product"
|
||||
type: number
|
||||
inStock:
|
||||
description: "Availability status of the product."
|
||||
type: boolean
|
||||
tags:
|
||||
description: "List of product tags"
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required: ["name", "price", "inStock"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': 'Successfully added product with id: 123'}"
|
||||
expected:
|
||||
- num_tool_calls: 1
|
||||
tool_name: addProduct
|
||||
tool_arguments:
|
||||
name: "Widget"
|
||||
price: 19.99
|
||||
inStock: true
|
||||
tags:
|
||||
- "new"
|
||||
- "sale"
|
||||
- num_tool_calls: 0
|
||||
answer: ["123", "product id: 123"]
|
||||
- case_id: "get_then_create_event_tool"
|
||||
input:
|
||||
messages:
|
||||
- - role: system
|
||||
content: "Todays date is 2025-03-01."
|
||||
- role: user
|
||||
content: "Do i have any meetings on March 3rd at 10 am? Yes or no?"
|
||||
- - role: user
|
||||
content: "Alright then, Create an event named 'Team Building', scheduled for that time same time, in the 'Main Conference Room' and add Alice, Bob, Charlie to it. Give me the created event id."
|
||||
tools:
|
||||
- function:
|
||||
description: Create a new event
|
||||
name: create_event
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: "Name of the event"
|
||||
type: string
|
||||
date:
|
||||
description: "Date of the event in ISO format"
|
||||
type: string
|
||||
time:
|
||||
description: "Event Time (HH:MM)"
|
||||
type: string
|
||||
location:
|
||||
description: "Location of the event"
|
||||
type: string
|
||||
participants:
|
||||
description: "List of participant names"
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required: ["name", "date", "time", "location", "participants"]
|
||||
type: function
|
||||
- function:
|
||||
description: Get an event by date and time
|
||||
name: get_event
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
date:
|
||||
description: "Date of the event in ISO format"
|
||||
type: string
|
||||
time:
|
||||
description: "Event Time (HH:MM)"
|
||||
type: string
|
||||
required: ["date", "time"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': 'No events found for 2025-03-03 at 10:00'}"
|
||||
- response: "{'response': 'Successfully created new event with id: e_123'}"
|
||||
expected:
|
||||
- num_tool_calls: 1
|
||||
tool_name: get_event
|
||||
tool_arguments:
|
||||
date: "2025-03-03"
|
||||
time: "10:00"
|
||||
- num_tool_calls: 0
|
||||
answer: ["no", "no events found", "no meetings"]
|
||||
- num_tool_calls: 1
|
||||
tool_name: create_event
|
||||
tool_arguments:
|
||||
name: "Team Building"
|
||||
date: "2025-03-03"
|
||||
time: "10:00"
|
||||
location: "Main Conference Room"
|
||||
participants:
|
||||
- "Alice"
|
||||
- "Bob"
|
||||
- "Charlie"
|
||||
- num_tool_calls: 0
|
||||
answer: ["e_123", "event id: e_123"]
|
||||
- case_id: "compare_monthly_expense_tool"
|
||||
input:
|
||||
messages:
|
||||
- - role: system
|
||||
content: "Todays date is 2025-03-01."
|
||||
- role: user
|
||||
content: "what was my monthly expense in Jan of this year?"
|
||||
- - role: user
|
||||
content: "Was it less than Feb of last year? Only answer with yes or no."
|
||||
tools:
|
||||
- function:
|
||||
description: Get monthly expense summary
|
||||
name: getMonthlyExpenseSummary
|
||||
parameters:
|
||||
type: object
|
||||
properties:
|
||||
month:
|
||||
description: "Month of the year (1-12)"
|
||||
type: integer
|
||||
year:
|
||||
description: "Year"
|
||||
type: integer
|
||||
required: ["month", "year"]
|
||||
type: function
|
||||
tool_responses:
|
||||
- response: "{'response': 'Total expenses for January 2025: $1000'}"
|
||||
- response: "{'response': 'Total expenses for February 2024: $2000'}"
|
||||
expected:
|
||||
- num_tool_calls: 1
|
||||
tool_name: getMonthlyExpenseSummary
|
||||
tool_arguments:
|
||||
month: 1
|
||||
year: 2025
|
||||
- num_tool_calls: 0
|
||||
answer: ["1000", "$1,000", "1,000"]
|
||||
- num_tool_calls: 1
|
||||
tool_name: getMonthlyExpenseSummary
|
||||
tool_arguments:
|
||||
month: 2
|
||||
year: 2024
|
||||
- num_tool_calls: 0
|
||||
answer: ["yes"]
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
@ -243,43 +244,294 @@ def test_chat_streaming_tool_calling(request, openai_client, model, provider, ve
|
|||
stream=True,
|
||||
)
|
||||
|
||||
# Accumulate partial tool_calls here
|
||||
tool_calls_buffer = {}
|
||||
current_id = None
|
||||
# Process streaming chunks
|
||||
for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta
|
||||
|
||||
if delta.tool_calls is None:
|
||||
continue
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
if tool_call_delta.id:
|
||||
current_id = tool_call_delta.id
|
||||
call_id = current_id
|
||||
func_delta = tool_call_delta.function
|
||||
|
||||
if call_id not in tool_calls_buffer:
|
||||
tool_calls_buffer[call_id] = {
|
||||
"id": call_id,
|
||||
"type": tool_call_delta.type,
|
||||
"name": func_delta.name,
|
||||
"arguments": "",
|
||||
}
|
||||
|
||||
if func_delta.arguments:
|
||||
tool_calls_buffer[call_id]["arguments"] += func_delta.arguments
|
||||
|
||||
_, tool_calls_buffer = _accumulate_streaming_tool_calls(stream)
|
||||
assert len(tool_calls_buffer) == 1
|
||||
for call in tool_calls_buffer.values():
|
||||
for call in tool_calls_buffer:
|
||||
assert len(call["id"]) > 0
|
||||
assert call["name"] == "get_weather"
|
||||
function = call["function"]
|
||||
assert function["name"] == "get_weather"
|
||||
|
||||
args_dict = json.loads(call["arguments"])
|
||||
args_dict = json.loads(function["arguments"])
|
||||
assert "san francisco" in args_dict["location"].lower()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_non_streaming_tool_choice_required(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
tools=case["input"]["tools"],
|
||||
tool_choice="required", # Force tool call
|
||||
stream=False,
|
||||
)
|
||||
print(response)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert len(response.choices[0].message.tool_calls) > 0, "Expected tool call when tool_choice='required'"
|
||||
expected_tool_name = case["input"]["tools"][0]["function"]["name"]
|
||||
assert response.choices[0].message.tool_calls[0].function.name == expected_tool_name
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_streaming_tool_choice_required(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
stream = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
tools=case["input"]["tools"],
|
||||
tool_choice="required", # Force tool call
|
||||
stream=True,
|
||||
)
|
||||
|
||||
_, tool_calls_buffer = _accumulate_streaming_tool_calls(stream)
|
||||
|
||||
assert len(tool_calls_buffer) > 0, "Expected tool call when tool_choice='required'"
|
||||
expected_tool_name = case["input"]["tools"][0]["function"]["name"]
|
||||
assert any(call["function"]["name"] == expected_tool_name for call in tool_calls_buffer), (
|
||||
f"Expected tool call '{expected_tool_name}' not found in stream"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_non_streaming_tool_choice_none(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
tools=case["input"]["tools"],
|
||||
tool_choice="none",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert response.choices[0].message.tool_calls is None, "Expected no tool calls when tool_choice='none'"
|
||||
assert response.choices[0].message.content is not None, "Expected content when tool_choice='none'"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"], # Reusing existing case for now
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_streaming_tool_choice_none(request, openai_client, model, provider, verification_config, case):
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
stream = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=case["input"]["messages"],
|
||||
tools=case["input"]["tools"],
|
||||
tool_choice="none",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
content = ""
|
||||
for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.content:
|
||||
content += delta.content
|
||||
assert not delta.tool_calls, "Expected no tool call chunks when tool_choice='none'"
|
||||
|
||||
assert len(content) > 0, "Expected content when tool_choice='none'"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases.get("test_chat_multi_turn_tool_calling", {}).get("test_params", {}).get("case", []),
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_non_streaming_multi_turn_tool_calling(request, openai_client, model, provider, verification_config, case):
|
||||
"""
|
||||
Test cases for multi-turn tool calling.
|
||||
Tool calls are asserted.
|
||||
Tool responses are provided in the test case.
|
||||
Final response is asserted.
|
||||
"""
|
||||
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
# Create a copy of the messages list to avoid modifying the original
|
||||
messages = []
|
||||
tools = case["input"]["tools"]
|
||||
# Use deepcopy to prevent modification across runs/parametrization
|
||||
expected_results = copy.deepcopy(case["expected"])
|
||||
tool_responses = copy.deepcopy(case.get("tool_responses", []))
|
||||
input_messages_turns = copy.deepcopy(case["input"]["messages"])
|
||||
|
||||
# keep going until either
|
||||
# 1. we have messages to test in multi-turn
|
||||
# 2. no messages but last message is tool response
|
||||
while len(input_messages_turns) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
|
||||
# do not take new messages if last message is tool response
|
||||
if len(messages) == 0 or messages[-1]["role"] != "tool":
|
||||
new_messages = input_messages_turns.pop(0)
|
||||
# Ensure new_messages is a list of message objects
|
||||
if isinstance(new_messages, list):
|
||||
messages.extend(new_messages)
|
||||
else:
|
||||
# If it's a single message object, add it directly
|
||||
messages.append(new_messages)
|
||||
|
||||
# --- API Call ---
|
||||
response = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# --- Process Response ---
|
||||
assistant_message = response.choices[0].message
|
||||
messages.append(assistant_message.model_dump(exclude_unset=True))
|
||||
|
||||
assert assistant_message.role == "assistant"
|
||||
|
||||
# Get the expected result data
|
||||
expected = expected_results.pop(0)
|
||||
num_tool_calls = expected["num_tool_calls"]
|
||||
|
||||
# --- Assertions based on expected result ---
|
||||
assert len(assistant_message.tool_calls or []) == num_tool_calls, (
|
||||
f"Expected {num_tool_calls} tool calls, but got {len(assistant_message.tool_calls or [])}"
|
||||
)
|
||||
|
||||
if num_tool_calls > 0:
|
||||
tool_call = assistant_message.tool_calls[0]
|
||||
assert tool_call.function.name == expected["tool_name"], (
|
||||
f"Expected tool '{expected['tool_name']}', got '{tool_call.function.name}'"
|
||||
)
|
||||
# Parse the JSON string arguments before comparing
|
||||
actual_arguments = json.loads(tool_call.function.arguments)
|
||||
assert actual_arguments == expected["tool_arguments"], (
|
||||
f"Expected arguments '{expected['tool_arguments']}', got '{actual_arguments}'"
|
||||
)
|
||||
|
||||
# Prepare and append the tool response for the next turn
|
||||
tool_response = tool_responses.pop(0)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": tool_response["response"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
assert assistant_message.content is not None, "Expected content, but none received."
|
||||
expected_answers = expected["answer"] # This is now a list
|
||||
content_lower = assistant_message.content.lower()
|
||||
assert any(ans.lower() in content_lower for ans in expected_answers), (
|
||||
f"Expected one of {expected_answers} in content, but got: '{assistant_message.content}'"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
chat_completion_test_cases.get("test_chat_multi_turn_tool_calling", {}).get("test_params", {}).get("case", []),
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_chat_streaming_multi_turn_tool_calling(request, openai_client, model, provider, verification_config, case):
|
||||
""" """
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
messages = []
|
||||
tools = case["input"]["tools"]
|
||||
expected_results = copy.deepcopy(case["expected"])
|
||||
tool_responses = copy.deepcopy(case.get("tool_responses", []))
|
||||
input_messages_turns = copy.deepcopy(case["input"]["messages"])
|
||||
|
||||
while len(input_messages_turns) > 0 or (len(messages) > 0 and messages[-1]["role"] == "tool"):
|
||||
if len(messages) == 0 or messages[-1]["role"] != "tool":
|
||||
new_messages = input_messages_turns.pop(0)
|
||||
if isinstance(new_messages, list):
|
||||
messages.extend(new_messages)
|
||||
else:
|
||||
messages.append(new_messages)
|
||||
|
||||
# --- API Call (Streaming) ---
|
||||
stream = openai_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# --- Process Stream ---
|
||||
accumulated_content, accumulated_tool_calls = _accumulate_streaming_tool_calls(stream)
|
||||
|
||||
# --- Construct Assistant Message for History ---
|
||||
assistant_message_dict = {"role": "assistant"}
|
||||
if accumulated_content:
|
||||
assistant_message_dict["content"] = accumulated_content
|
||||
if accumulated_tool_calls:
|
||||
assistant_message_dict["tool_calls"] = accumulated_tool_calls
|
||||
|
||||
messages.append(assistant_message_dict)
|
||||
|
||||
# --- Assertions ---
|
||||
expected = expected_results.pop(0)
|
||||
num_tool_calls = expected["num_tool_calls"]
|
||||
|
||||
assert len(accumulated_tool_calls or []) == num_tool_calls, (
|
||||
f"Expected {num_tool_calls} tool calls, but got {len(accumulated_tool_calls or [])}"
|
||||
)
|
||||
|
||||
if num_tool_calls > 0:
|
||||
# Use the first accumulated tool call for assertion
|
||||
tool_call = accumulated_tool_calls[0]
|
||||
assert tool_call["function"]["name"] == expected["tool_name"], (
|
||||
f"Expected tool '{expected['tool_name']}', got '{tool_call['function']['name']}'"
|
||||
)
|
||||
# Parse the accumulated arguments string for comparison
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
assert actual_arguments == expected["tool_arguments"], (
|
||||
f"Expected arguments '{expected['tool_arguments']}', got '{actual_arguments}'"
|
||||
)
|
||||
|
||||
# Prepare and append the tool response for the next turn
|
||||
tool_response = tool_responses.pop(0)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call["id"],
|
||||
"content": tool_response["response"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
assert accumulated_content is not None and accumulated_content != "", "Expected content, but none received."
|
||||
expected_answers = expected["answer"]
|
||||
content_lower = accumulated_content.lower()
|
||||
assert any(ans.lower() in content_lower for ans in expected_answers), (
|
||||
f"Expected one of {expected_answers} in content, but got: '{accumulated_content}'"
|
||||
)
|
||||
|
||||
|
||||
# --- Helper functions (structured output validation) ---
|
||||
|
||||
|
||||
|
@ -324,3 +576,47 @@ def validate_structured_output(maybe_json_content: str, schema_name: str) -> Non
|
|||
assert len(structured_output.participants) == 2
|
||||
elif schema_name == "valid_math_reasoning":
|
||||
assert len(structured_output.final_answer) > 0
|
||||
|
||||
|
||||
def _accumulate_streaming_tool_calls(stream):
|
||||
"""Accumulates tool calls and content from a streaming ChatCompletion response."""
|
||||
tool_calls_buffer = {}
|
||||
current_id = None
|
||||
full_content = "" # Initialize content accumulator
|
||||
# Process streaming chunks
|
||||
for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta
|
||||
|
||||
# Accumulate content
|
||||
if delta.content:
|
||||
full_content += delta.content
|
||||
|
||||
if delta.tool_calls is None:
|
||||
continue
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
if tool_call_delta.id:
|
||||
current_id = tool_call_delta.id
|
||||
call_id = current_id
|
||||
# Skip if no ID seen yet for this tool call delta
|
||||
if not call_id:
|
||||
continue
|
||||
func_delta = tool_call_delta.function
|
||||
|
||||
if call_id not in tool_calls_buffer:
|
||||
tool_calls_buffer[call_id] = {
|
||||
"id": call_id,
|
||||
"type": "function", # Assume function type
|
||||
"function": {"name": None, "arguments": ""}, # Nested structure
|
||||
}
|
||||
|
||||
# Accumulate name and arguments into the nested function dict
|
||||
if func_delta:
|
||||
if func_delta.name:
|
||||
tool_calls_buffer[call_id]["function"]["name"] = func_delta.name
|
||||
if func_delta.arguments:
|
||||
tool_calls_buffer[call_id]["function"]["arguments"] += func_delta.arguments
|
||||
|
||||
# Return content and tool calls as a list
|
||||
return full_content, list(tool_calls_buffer.values())
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load diff
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue