Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 14:08:00 +00:00)

Merge branch 'main' into patch-1

Commit 92c2edd61c: 35 changed files with 1916 additions and 1589 deletions
.github/workflows/pre-commit.yml (vendored, 38 changed lines)

@@ -14,10 +14,18 @@ concurrency:
 jobs:
   pre-commit:
     runs-on: ubuntu-latest
+    permissions:
+      contents: write
+      pull-requests: write

     steps:
       - name: Checkout code
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          # For dependabot PRs, we need to checkout with a token that can push changes
+          token: ${{ github.actor == 'dependabot[bot]' && secrets.GITHUB_TOKEN || github.token }}
+          # Fetch full history for dependabot PRs to allow commits
+          fetch-depth: ${{ github.actor == 'dependabot[bot]' && 0 || 1 }}

       - name: Set up Python
         uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0

@@ -29,15 +37,45 @@ jobs:
             .pre-commit-config.yaml

       - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
+        continue-on-error: true
         env:
           SKIP: no-commit-to-branch
           RUFF_OUTPUT_FORMAT: github

+      - name: Debug
+        run: |
+          echo "github.ref: ${{ github.ref }}"
+          echo "github.actor: ${{ github.actor }}"
+
+      - name: Commit changes for dependabot PRs
+        if: github.actor == 'dependabot[bot]'
+        run: |
+          if ! git diff --exit-code || [ -n "$(git ls-files --others --exclude-standard)" ]; then
+            git config --local user.email "github-actions[bot]@users.noreply.github.com"
+            git config --local user.name "github-actions[bot]"
+
+            # Ensure we're on the correct branch
+            git checkout -B ${{ github.head_ref }}
+            git add -A
+            git commit -m "Apply pre-commit fixes"
+
+            # Pull latest changes from the PR branch and rebase our commit on top
+            git pull --rebase origin ${{ github.head_ref }}
+
+            # Push to the PR branch
+            git push origin ${{ github.head_ref }}
+            echo "Pre-commit fixes committed and pushed"
+          else
+            echo "No changes to commit"
+          fi
+
       - name: Verify if there are any diff files after pre-commit
+        if: github.actor != 'dependabot[bot]'
         run: |
           git diff --exit-code || (echo "There are uncommitted changes, run pre-commit locally and commit again" && exit 1)

       - name: Verify if there are any new files after pre-commit
+        if: github.actor != 'dependabot[bot]'
         run: |
           unstaged_files=$(git ls-files --others --exclude-standard)
           if [ -n "$unstaged_files" ]; then
.github/workflows/test-external-provider-module.yml (vendored)

@@ -13,6 +13,7 @@ on:
       - 'uv.lock'
       - 'pyproject.toml'
       - 'requirements.txt'
+      - 'tests/external/*'
       - '.github/workflows/test-external-provider-module.yml' # This workflow

 jobs:
.github/workflows/test-external.yml (vendored, 1 changed line)

@@ -13,6 +13,7 @@ on:
       - 'uv.lock'
       - 'pyproject.toml'
       - 'requirements.txt'
+      - 'tests/external/*'
       - '.github/workflows/test-external.yml' # This workflow

 jobs:
.github/workflows/unit-tests.yml (vendored, 2 changed lines)

@@ -35,6 +35,8 @@ jobs:

       - name: Install dependencies
         uses: ./.github/actions/setup-runner
+        with:
+          python-version: ${{ matrix.python }}

       - name: Run unit tests
         run: |
.pre-commit-config.yaml

@@ -19,7 +19,6 @@ repos:
       - id: check-yaml
         args: ["--unsafe"]
       - id: detect-private-key
-      - id: requirements-txt-fixer
       - id: mixed-line-ending
         args: [--fix=lf] # Forces to replace line ending by LF (line feed)
       - id: check-executables-have-shebangs

@@ -56,14 +55,6 @@ repos:
     rev: 0.7.20
     hooks:
       - id: uv-lock
-      - id: uv-export
-        args: [
-          "--frozen",
-          "--no-hashes",
-          "--no-emit-project",
-          "--no-default-groups",
-          "--output-file=requirements.txt"
-        ]

   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.16.1
docs/_static/llama-stack-spec.html (vendored, 49 changed lines)

@@ -9770,7 +9770,7 @@
            {
              "type": "array",
              "items": {
-                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
              }
            }
          ],

@@ -9821,13 +9821,17 @@
            },
            {
              "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
+            },
+            {
+              "$ref": "#/components/schemas/OpenAIFile"
            }
          ],
          "discriminator": {
            "propertyName": "type",
            "mapping": {
              "text": "#/components/schemas/OpenAIChatCompletionContentPartTextParam",
-              "image_url": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
+              "image_url": "#/components/schemas/OpenAIChatCompletionContentPartImageParam",
+              "file": "#/components/schemas/OpenAIFile"
            }
          }
        },

@@ -9955,7 +9959,7 @@
            {
              "type": "array",
              "items": {
-                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
              }
            }
          ],

@@ -9974,6 +9978,41 @@
        "title": "OpenAIDeveloperMessageParam",
        "description": "A message from the developer in an OpenAI-compatible chat completion request."
      },
+      "OpenAIFile": {
+        "type": "object",
+        "properties": {
+          "type": {
+            "type": "string",
+            "const": "file",
+            "default": "file"
+          },
+          "file": {
+            "$ref": "#/components/schemas/OpenAIFileFile"
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "type",
+          "file"
+        ],
+        "title": "OpenAIFile"
+      },
+      "OpenAIFileFile": {
+        "type": "object",
+        "properties": {
+          "file_data": {
+            "type": "string"
+          },
+          "file_id": {
+            "type": "string"
+          },
+          "filename": {
+            "type": "string"
+          }
+        },
+        "additionalProperties": false,
+        "title": "OpenAIFileFile"
+      },
      "OpenAIImageURL": {
        "type": "object",
        "properties": {

@@ -10036,7 +10075,7 @@
            {
              "type": "array",
              "items": {
-                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
              }
            }
          ],

@@ -10107,7 +10146,7 @@
            {
              "type": "array",
              "items": {
-                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
              }
            }
          ],
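The new `OpenAIFile` / `OpenAIFileFile` schemas add a third, `file`-discriminated variant to the content-part union. For orientation, a hedged sketch of a request-body fragment that would validate against these schemas; the field names come from the schemas above, the values are illustrative:

```python
# Illustrative content part matching the new "file" variant of
# OpenAIChatCompletionContentPartParam (values are made up).
file_content_part = {
    "type": "file",  # discriminator value, const "file"
    "file": {
        # OpenAIFileFile: all three fields are optional strings
        "filename": "hello-world.pdf",
        "file_data": "data:application/pdf;base64,JVBERi0xLjQK",  # tiny base64 PDF header, for illustration
        # "file_id": "file-abc123",  # alternative: reference an already-uploaded file
    },
}
```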
docs/_static/llama-stack-spec.yaml (vendored, 35 changed lines)

@@ -6895,7 +6895,7 @@ components:
           - type: string
           - type: array
             items:
-              $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+              $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
         description: The content of the model's response
         name:
           type: string

@@ -6934,11 +6934,13 @@ components:
       oneOf:
         - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
         - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
+        - $ref: '#/components/schemas/OpenAIFile'
       discriminator:
         propertyName: type
         mapping:
           text: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
           image_url: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
+          file: '#/components/schemas/OpenAIFile'
     OpenAIChatCompletionContentPartTextParam:
       type: object
       properties:

@@ -7037,7 +7039,7 @@ components:
           - type: string
           - type: array
             items:
-              $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+              $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
         description: The content of the developer message
         name:
           type: string

@@ -7050,6 +7052,31 @@ components:
       title: OpenAIDeveloperMessageParam
       description: >-
         A message from the developer in an OpenAI-compatible chat completion request.
+    OpenAIFile:
+      type: object
+      properties:
+        type:
+          type: string
+          const: file
+          default: file
+        file:
+          $ref: '#/components/schemas/OpenAIFileFile'
+      additionalProperties: false
+      required:
+        - type
+        - file
+      title: OpenAIFile
+    OpenAIFileFile:
+      type: object
+      properties:
+        file_data:
+          type: string
+        file_id:
+          type: string
+        filename:
+          type: string
+      additionalProperties: false
+      title: OpenAIFileFile
     OpenAIImageURL:
       type: object
       properties:

@@ -7090,7 +7117,7 @@ components:
           - type: string
           - type: array
             items:
-              $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+              $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
         description: >-
           The content of the "system prompt". If multiple system messages are provided,
           they are concatenated. The underlying Llama Stack code may also add other

@@ -7148,7 +7175,7 @@ components:
           - type: string
           - type: array
             items:
-              $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+              $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
         description: The response content from the tool
       additionalProperties: false
       required:
@@ -13,7 +13,7 @@ llama stack build --template starter --image-type venv
 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

 client = LlamaStackAsLibraryClient(
-    "ollama",
+    "starter",
     # provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
     provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
 )
@@ -12,8 +12,7 @@ To enable external providers, you need to add `module` into your build yaml, all
 an example entry in your build.yaml should look like:

 ```
-- provider_id: ramalama
-  provider_type: remote::ramalama
+- provider_type: remote::ramalama
   module: ramalama_stack
 ```

@@ -255,8 +254,7 @@ distribution_spec:
   container_image: null
   providers:
     inference:
-    - provider_id: ramalama
-      provider_type: remote::ramalama
+    - provider_type: remote::ramalama
       module: ramalama_stack==0.3.0a0
 image_type: venv
 image_name: null
@@ -9,11 +9,13 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `api_key` | `str \| None` | No | | API key for OpenAI models |
+| `base_url` | `<class 'str'>` | No | https://api.openai.com/v1 | Base URL for OpenAI API |

 ## Sample Configuration

 ```yaml
 api_key: ${env.OPENAI_API_KEY:=}
+base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}

 ```
@@ -455,8 +455,21 @@ class OpenAIChatCompletionContentPartImageParam(BaseModel):
     image_url: OpenAIImageURL


+@json_schema_type
+class OpenAIFileFile(BaseModel):
+    file_data: str | None = None
+    file_id: str | None = None
+    filename: str | None = None
+
+
+@json_schema_type
+class OpenAIFile(BaseModel):
+    type: Literal["file"] = "file"
+    file: OpenAIFileFile
+
+
 OpenAIChatCompletionContentPartParam = Annotated[
-    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")

@@ -464,6 +477,8 @@ register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletion

 OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]

+OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
+

 @json_schema_type
 class OpenAIUserMessageParam(BaseModel):

@@ -489,7 +504,7 @@ class OpenAISystemMessageParam(BaseModel):
     """

     role: Literal["system"] = "system"
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent
     name: str | None = None


@@ -518,7 +533,7 @@ class OpenAIAssistantMessageParam(BaseModel):
     """

     role: Literal["assistant"] = "assistant"
-    content: OpenAIChatCompletionMessageContent | None = None
+    content: OpenAIChatCompletionTextOnlyMessageContent | None = None
     name: str | None = None
     tool_calls: list[OpenAIChatCompletionToolCall] | None = None

@@ -534,7 +549,7 @@ class OpenAIToolMessageParam(BaseModel):

     role: Literal["tool"] = "tool"
     tool_call_id: str
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent


 @json_schema_type

@@ -547,7 +562,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
     """

     role: Literal["developer"] = "developer"
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent
     name: str | None = None

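A minimal usage sketch of the new Pydantic types above; the class names, fields, and the import module match the hunks and tests in this commit, while the message construction, the assumption that `OpenAIFile`/`OpenAIFileFile` are importable from `llama_stack.apis.inference`, and the sample values are illustrative:

```python
from llama_stack.apis.inference import (
    OpenAIChatCompletionContentPartTextParam,
    OpenAIFile,
    OpenAIFileFile,
    OpenAISystemMessageParam,
    OpenAIUserMessageParam,
)

# A file content part: inline data URL here, or file_id for an uploaded file.
file_part = OpenAIFile(
    file=OpenAIFileFile(
        filename="report.pdf",  # illustrative
        file_data="data:application/pdf;base64,JVBERi0xLjQK",  # tiny base64 PDF header, for illustration
    )
)
text_part = OpenAIChatCompletionContentPartTextParam(text="Summarize this PDF.")

# User messages accept the full content-part union (text, image, and now file)...
user_msg = OpenAIUserMessageParam(content=[text_part, file_part])

# ...while system/developer/tool/assistant content is now text-only, so only a
# string or a list of text parts validates; a file or image part would raise.
system_msg = OpenAISystemMessageParam(
    content=[OpenAIChatCompletionContentPartTextParam(text="Be brief.")]
)
```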
@@ -18,10 +18,6 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}

 # mounting is not supported by docker buildx, so we use COPY instead
 USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
-
-# Mount command for cache container .cache, can be overridden by the user if needed
-MOUNT_CACHE=${MOUNT_CACHE:-"--mount=type=cache,id=llama-stack-cache,target=/root/.cache"}

 # Path to the run.yaml file in the container
 RUN_CONFIG_PATH=/app/run.yaml

@@ -176,18 +172,13 @@ RUN pip install uv
 EOF
 fi

-# Set the link mode to copy so that uv doesn't attempt to symlink to the cache directory
-add_to_container << EOF
-ENV UV_LINK_MODE=copy
-EOF
-
 # Add pip dependencies first since llama-stack is what will change most often
 # so we can reuse layers.
 if [ -n "$normal_deps" ]; then
   read -ra pip_args <<< "$normal_deps"
   quoted_deps=$(printf " %q" "${pip_args[@]}")
   add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install $quoted_deps
+RUN uv pip install --no-cache $quoted_deps
 EOF
 fi

@@ -197,7 +188,7 @@ if [ -n "$optional_deps" ]; then
     read -ra pip_args <<< "$part"
     quoted_deps=$(printf " %q" "${pip_args[@]}")
     add_to_container <<EOF
-RUN $MOUNT_CACHE uv pip install $quoted_deps
+RUN uv pip install --no-cache $quoted_deps
 EOF
   done
 fi

@@ -208,10 +199,10 @@ if [ -n "$external_provider_deps" ]; then
     read -ra pip_args <<< "$part"
     quoted_deps=$(printf " %q" "${pip_args[@]}")
     add_to_container <<EOF
-RUN $MOUNT_CACHE uv pip install $quoted_deps
+RUN uv pip install --no-cache $quoted_deps
 EOF
     add_to_container <<EOF
-RUN python3 - <<PYTHON | $MOUNT_CACHE uv pip install -r -
+RUN python3 - <<PYTHON | uv pip install --no-cache -r -
 import importlib
 import sys

@@ -293,7 +284,7 @@ COPY $dir $mount_point
 EOF
   fi
   add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install -e $mount_point
+RUN uv pip install --no-cache -e $mount_point
 EOF
 }

@@ -308,10 +299,10 @@ else
   if [ -n "$TEST_PYPI_VERSION" ]; then
     # these packages are damaged in test-pypi, so install them first
     add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install fastapi libcst
+RUN uv pip install --no-cache fastapi libcst
 EOF
     add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install --extra-index-url https://test.pypi.org/simple/ \
+RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
   --index-strategy unsafe-best-match \
   llama-stack==$TEST_PYPI_VERSION

@@ -323,7 +314,7 @@ EOF
     SPEC_VERSION="llama-stack"
   fi
   add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install $SPEC_VERSION
+RUN uv pip install --no-cache $SPEC_VERSION
 EOF
 fi
 fi
@@ -358,7 +358,7 @@ async def shutdown_stack(impls: dict[Api, Any]):


 async def refresh_registry_once(impls: dict[Api, Any]):
-    logger.info("refreshing registry")
+    logger.debug("refreshing registry")
     routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
     for routing_table in routing_tables:
         await routing_table.refresh()

@@ -469,7 +469,7 @@ class HFFinetuningSingleDevice:
             use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
             save_strategy=save_strategy,
             report_to="none",
-            max_seq_length=provider_config.max_seq_length,
+            max_length=provider_config.max_seq_length,
             gradient_accumulation_steps=config.gradient_accumulation_steps,
             gradient_checkpointing=provider_config.gradient_checkpointing,
             learning_rate=lr,
@@ -32,7 +32,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
         LiteLLMOpenAIMixin.__init__(
             self,
             model_entries=MODEL_ENTRIES,
-            litellm_provider_name="llama",
+            litellm_provider_name="meta_llama",
             api_key_from_config=config.api_key,
             provider_data_api_key_field="llama_api_key",
             openai_compat_api_base=config.openai_compat_api_base,
@@ -24,9 +24,19 @@ class OpenAIConfig(BaseModel):
         default=None,
         description="API key for OpenAI models",
     )
+    base_url: str = Field(
+        default="https://api.openai.com/v1",
+        description="Base URL for OpenAI API",
+    )

     @classmethod
-    def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY:=}", **kwargs) -> dict[str, Any]:
+    def sample_run_config(
+        cls,
+        api_key: str = "${env.OPENAI_API_KEY:=}",
+        base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
+        **kwargs,
+    ) -> dict[str, Any]:
         return {
             "api_key": api_key,
+            "base_url": base_url,
         }
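A small sketch of how the new field behaves; the class, its import path, and the defaults come from this commit, the concrete URLs are illustrative:

```python
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig

# Default: the provider keeps talking to the public OpenAI endpoint.
config = OpenAIConfig(api_key="sk-test")
assert config.base_url == "https://api.openai.com/v1"

# Override: point the same provider at any OpenAI-compatible server.
proxy_config = OpenAIConfig(api_key="sk-test", base_url="https://proxy.example.com/v1")

# sample_run_config() now emits both keys, with env-var placeholders as defaults:
# {'api_key': '${env.OPENAI_API_KEY:=}',
#  'base_url': '${env.OPENAI_BASE_URL:=https://api.openai.com/v1}'}
print(OpenAIConfig.sample_run_config())
```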
@@ -65,9 +65,9 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
         """
         Get the OpenAI API base URL.

-        Returns the standard OpenAI API base URL for direct OpenAI API calls.
+        Returns the OpenAI API base URL from the configuration.
         """
-        return "https://api.openai.com/v1"
+        return self.config.base_url

     async def initialize(self) -> None:
         await super().initialize()
@@ -73,6 +73,15 @@ class LiteLLMOpenAIMixin(
         provider_data_api_key_field: str,
         openai_compat_api_base: str | None = None,
     ):
+        """
+        Initialize the LiteLLMOpenAIMixin.
+
+        :param model_entries: The model entries to register.
+        :param api_key_from_config: The API key to use from the config.
+        :param provider_data_api_key_field: The field in the provider data that contains the API key.
+        :param litellm_provider_name: The name of the provider, used for model lookups.
+        :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
+        """
         ModelRegistryHelper.__init__(self, model_entries)

         self.litellm_provider_name = litellm_provider_name

@@ -428,3 +437,17 @@ class LiteLLMOpenAIMixin(
         logprobs: LogProbConfig | None = None,
     ):
         raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
+
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available via LiteLLM for the current
+        provider (self.litellm_provider_name).
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        if self.litellm_provider_name not in litellm.models_by_provider:
+            logger.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
+            return False
+
+        return model in litellm.models_by_provider[self.litellm_provider_name]
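For context, a hedged sketch of the lookup this new helper performs. `litellm.models_by_provider` is the registry the hunk itself references, and the provider key ties back to the `litellm_provider_name="meta_llama"` rename earlier in this commit; the model name and the expected results are illustrative and depend on the installed litellm version:

```python
import litellm


def is_model_available(provider_name: str, model: str) -> bool:
    # Same shape as the new check_model_availability, but as a standalone helper.
    if provider_name not in litellm.models_by_provider:
        return False  # unknown provider key, e.g. the old "llama" spelling
    return model in litellm.models_by_provider[provider_name]


# Illustrative calls; actual results depend on litellm's registry contents.
print(is_model_available("meta_llama", "meta_llama/Llama-3.3-70B-Instruct"))
print(is_model_available("llama", "anything"))  # expected False given the rename in this commit
```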
@@ -56,6 +56,7 @@ providers:
     provider_type: remote::openai
     config:
       api_key: ${env.OPENAI_API_KEY:=}
+      base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
   - provider_id: anthropic
     provider_type: remote::anthropic
     config:

@@ -16,6 +16,7 @@ providers:
     provider_type: remote::openai
     config:
       api_key: ${env.OPENAI_API_KEY:=}
+      base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
   - provider_id: anthropic
     provider_type: remote::anthropic
     config:

@@ -56,6 +56,7 @@ providers:
     provider_type: remote::openai
     config:
       api_key: ${env.OPENAI_API_KEY:=}
+      base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
   - provider_id: anthropic
     provider_type: remote::anthropic
     config:
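These run.yaml values use the `${env.VAR:=default}` substitution syntax. A hedged sketch of how the new unit tests exercise it, using `replace_env_vars` from `llama_stack.distribution.stack` (the same helper the tests import); the environment values and expected outputs are illustrative:

```python
import os

from llama_stack.distribution.stack import replace_env_vars

config_data = {
    "api_key": "${env.OPENAI_API_KEY:=}",
    "base_url": "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
}

# With OPENAI_BASE_URL unset, the := default applies.
print(replace_env_vars(config_data)["base_url"])  # expected: https://api.openai.com/v1

# With the variable set, the provider is pointed at that endpoint instead.
os.environ["OPENAI_BASE_URL"] = "https://proxy.example.com/v1"
print(replace_env_vars(config_data)["base_url"])  # expected: https://proxy.example.com/v1
```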
@@ -20,7 +20,7 @@
     "@radix-ui/react-tooltip": "^1.2.6",
     "class-variance-authority": "^0.7.1",
     "clsx": "^2.1.1",
-    "llama-stack-client": "^0.2.15",
+    "llama-stack-client": "^0.2.16",
     "lucide-react": "^0.510.0",
     "next": "15.3.3",
     "next-auth": "^4.24.11",
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "llama_stack"
-version = "0.2.15"
+version = "0.2.16"
 authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
 description = "Llama Stack"
 readme = "README.md"

@@ -28,7 +28,7 @@ dependencies = [
     "huggingface-hub>=0.34.0,<1.0",
     "jinja2>=3.1.6",
     "jsonschema",
-    "llama-stack-client>=0.2.15",
+    "llama-stack-client>=0.2.16",
     "llama-api-client>=0.1.2",
     "openai>=1.66",
     "prompt-toolkit",

@@ -53,7 +53,7 @@ dependencies = [
 ui = [
     "streamlit",
     "pandas",
-    "llama-stack-client>=0.2.15",
+    "llama-stack-client>=0.2.16",
     "streamlit-option-menu",
 ]

@@ -114,6 +114,7 @@ test = [
     "sqlalchemy[asyncio]>=2.0.41",
     "requests",
     "pymilvus>=2.5.12",
+    "reportlab",
 ]
 docs = [
     "setuptools",
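`reportlab` joins the test dependencies because the new integration test further down in this commit generates a throwaway PDF to attach to a chat completion. A minimal sketch of that generation step, assuming only reportlab and the standard library; the output path is illustrative:

```python
import base64

from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas

# Draw a one-line PDF and turn it into the data URL the new file content part expects.
pdf_path = "/tmp/hello-world.pdf"  # illustrative location
c = canvas.Canvas(pdf_path, pagesize=letter)
c.drawString(100, 750, "Hello World")
c.save()

with open(pdf_path, "rb") as f:
    pdf_base64 = base64.b64encode(f.read()).decode("utf-8")
file_data = f"data:application/pdf;base64,{pdf_base64}"
```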
requirements.txt (269 lines removed)

@@ -1,269 +0,0 @@
-# This file was autogenerated by uv via the following command:
-#    uv export --frozen --no-hashes --no-emit-project --no-default-groups --output-file=requirements.txt
-aiohappyeyeballs==2.5.0  # via aiohttp
-aiohttp==3.12.13  # via llama-stack
-aiosignal==1.3.2  # via aiohttp
-aiosqlite==0.21.0  # via llama-stack
-annotated-types==0.7.0  # via pydantic
-anyio==4.8.0  # via httpx, llama-api-client, llama-stack-client, openai, starlette
-asyncpg==0.30.0  # via llama-stack
-attrs==25.1.0  # via aiohttp, jsonschema, referencing
-certifi==2025.1.31  # via httpcore, httpx, requests
-cffi==1.17.1 ; platform_python_implementation != 'PyPy'  # via cryptography
-charset-normalizer==3.4.1  # via requests
-click==8.1.8  # via llama-stack-client, uvicorn
-colorama==0.4.6 ; sys_platform == 'win32'  # via click, tqdm
-cryptography==45.0.5  # via python-jose
-deprecated==1.2.18  # via opentelemetry-api, opentelemetry-exporter-otlp-proto-http, opentelemetry-semantic-conventions
-distro==1.9.0  # via llama-api-client, llama-stack-client, openai
-ecdsa==0.19.1  # via python-jose
-fastapi==0.115.8  # via llama-stack
-filelock==3.17.0  # via huggingface-hub
-fire==0.7.0  # via llama-stack, llama-stack-client
-frozenlist==1.5.0  # via aiohttp, aiosignal
-fsspec==2024.12.0  # via huggingface-hub
-googleapis-common-protos==1.67.0  # via opentelemetry-exporter-otlp-proto-http
-h11==0.16.0  # via httpcore, llama-stack, uvicorn
-hf-xet==1.1.5 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'  # via huggingface-hub
-httpcore==1.0.9  # via httpx
-httpx==0.28.1  # via llama-api-client, llama-stack, llama-stack-client, openai
-huggingface-hub==0.34.1  # via llama-stack
-idna==3.10  # via anyio, httpx, requests, yarl
-importlib-metadata==8.5.0  # via opentelemetry-api
-jinja2==3.1.6  # via llama-stack
-jiter==0.8.2  # via openai
-jsonschema==4.23.0  # via llama-stack
-jsonschema-specifications==2024.10.1  # via jsonschema
-llama-api-client==0.1.2  # via llama-stack
-llama-stack-client==0.2.15  # via llama-stack
-markdown-it-py==3.0.0  # via rich
-markupsafe==3.0.2  # via jinja2
-mdurl==0.1.2  # via markdown-it-py
-multidict==6.1.0  # via aiohttp, yarl
-numpy==2.2.3  # via pandas
-openai==1.71.0  # via llama-stack
-opentelemetry-api==1.30.0  # via opentelemetry-exporter-otlp-proto-http, opentelemetry-sdk, opentelemetry-semantic-conventions
-opentelemetry-exporter-otlp-proto-common==1.30.0  # via opentelemetry-exporter-otlp-proto-http
-opentelemetry-exporter-otlp-proto-http==1.30.0  # via llama-stack
-opentelemetry-proto==1.30.0  # via opentelemetry-exporter-otlp-proto-common, opentelemetry-exporter-otlp-proto-http
-opentelemetry-sdk==1.30.0  # via llama-stack, opentelemetry-exporter-otlp-proto-http
-opentelemetry-semantic-conventions==0.51b0  # via opentelemetry-sdk
-packaging==24.2  # via huggingface-hub
-pandas==2.2.3  # via llama-stack-client
-pillow==11.1.0  # via llama-stack
-prompt-toolkit==3.0.50  # via llama-stack, llama-stack-client
-propcache==0.3.0  # via aiohttp, yarl
-protobuf==5.29.5  # via googleapis-common-protos, opentelemetry-proto
-pyaml==25.1.0  # via llama-stack-client
-pyasn1==0.4.8  # via python-jose, rsa
-pycparser==2.22 ; platform_python_implementation != 'PyPy'  # via cffi
-pydantic==2.10.6  # via fastapi, llama-api-client, llama-stack, llama-stack-client, openai
-pydantic-core==2.27.2  # via pydantic
-pygments==2.19.1  # via rich
-python-dateutil==2.9.0.post0  # via pandas
-python-dotenv==1.0.1  # via llama-stack
-python-jose==3.4.0  # via llama-stack
-python-multipart==0.0.20  # via llama-stack
-pytz==2025.1  # via pandas
-pyyaml==6.0.2  # via huggingface-hub, pyaml
-referencing==0.36.2  # via jsonschema, jsonschema-specifications
-regex==2024.11.6  # via tiktoken
-requests==2.32.4  # via huggingface-hub, llama-stack-client, opentelemetry-exporter-otlp-proto-http, tiktoken
-rich==13.9.4  # via llama-stack, llama-stack-client
-rpds-py==0.22.3  # via jsonschema, referencing
-rsa==4.9  # via python-jose
-six==1.17.0  # via ecdsa, python-dateutil
-sniffio==1.3.1  # via anyio, llama-api-client, llama-stack-client, openai
-starlette==0.45.3  # via fastapi, llama-stack
-termcolor==2.5.0  # via fire, llama-stack, llama-stack-client
-tiktoken==0.9.0  # via llama-stack
-tqdm==4.67.1  # via huggingface-hub, llama-stack-client, openai
-typing-extensions==4.12.2  # via aiosqlite, anyio, fastapi, huggingface-hub, llama-api-client, llama-stack-client, openai, opentelemetry-sdk, pydantic, pydantic-core, referencing
-tzdata==2025.1  # via pandas
-urllib3==2.5.0  # via requests
-uvicorn==0.34.0  # via llama-stack
-wcwidth==0.2.13  # via prompt-toolkit
-wrapt==1.17.2  # via deprecated
-yarl==1.18.3  # via aiohttp
-zipp==3.21.0  # via importlib-metadata
@@ -8,6 +8,15 @@

 PYTHON_VERSION=${PYTHON_VERSION:-3.12}

+set -e
+
+# Always run this at the end, even if something fails
+cleanup() {
+    echo "Generating coverage report..."
+    uv run --python "$PYTHON_VERSION" coverage html -d htmlcov-$PYTHON_VERSION
+}
+trap cleanup EXIT
+
 command -v uv >/dev/null 2>&1 || { echo >&2 "uv is required but it's not installed. Exiting."; exit 1; }

 uv python find "$PYTHON_VERSION"

@@ -19,6 +28,3 @@ fi
 # Run unit tests with coverage
 uv run --python "$PYTHON_VERSION" --with-editable . --group unit \
     coverage run --source=llama_stack -m pytest -s -v tests/unit/ "$@"
-
-# Generate HTML coverage report
-uv run --python "$PYTHON_VERSION" coverage html -d htmlcov-$PYTHON_VERSION
tests/external/build.yaml (vendored, 3 changed lines)

@@ -3,8 +3,7 @@ distribution_spec:
   description: Custom distro for CI tests
   providers:
     weather:
-    - provider_id: kaze
-      provider_type: remote::kaze
+    - provider_type: remote::kaze
 image_type: venv
 image_name: ci-test
 external_providers_dir: ~/.llama/providers.d
tests/external/ramalama-stack/build.yaml (vendored, 3 changed lines)

@@ -4,8 +4,7 @@ distribution_spec:
   container_image: null
   providers:
     inference:
-    - provider_id: ramalama
-      provider_type: remote::ramalama
+    - provider_type: remote::ramalama
       module: ramalama_stack==0.3.0a0
 image_type: venv
 image_name: ramalama-stack-test
@@ -5,8 +5,14 @@
 # the root directory of this source tree.


+import base64
+import os
+import tempfile
+
 import pytest
 from openai import OpenAI
+from reportlab.lib.pagesizes import letter
+from reportlab.pdfgen import canvas

 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient


@@ -82,6 +88,14 @@ def skip_if_provider_isnt_vllm(client_with_models, model_id):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support vllm extra_body parameters.")


+def skip_if_provider_isnt_openai(client_with_models, model_id):
+    provider = provider_from_model(client_with_models, model_id)
+    if provider.provider_type != "remote::openai":
+        pytest.skip(
+            f"Model {model_id} hosted by {provider.provider_type} doesn't support chat completion calls with base64 encoded files."
+        )
+
+
 @pytest.fixture
 def openai_client(client_with_models):
     base_url = f"{client_with_models.base_url}/v1/openai/v1"

@@ -418,3 +432,45 @@ def test_inference_store_tool_calls(compat_client, client_with_models, text_mode
     # failed tool call parses show up as a message with content, so ensure
     # that the retrieve response content matches the original request
     assert retrieved_response.choices[0].message.content == content
+
+
+def test_openai_chat_completion_non_streaming_with_file(openai_client, client_with_models, text_model_id):
+    skip_if_provider_isnt_openai(client_with_models, text_model_id)
+
+    # Generate temporary PDF with "Hello World" text
+    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_pdf:
+        c = canvas.Canvas(temp_pdf.name, pagesize=letter)
+        c.drawString(100, 750, "Hello World")
+        c.save()
+
+    # Read the PDF and sencode to base64
+    with open(temp_pdf.name, "rb") as pdf_file:
+        pdf_base64 = base64.b64encode(pdf_file.read()).decode("utf-8")
+
+    # Clean up temporary file
+    os.unlink(temp_pdf.name)
+
+    response = openai_client.chat.completions.create(
+        model=text_model_id,
+        messages=[
+            {
+                "role": "user",
+                "content": "Describe what you see in this PDF file.",
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "file",
+                        "file": {
+                            "filename": "my-temp-hello-world-pdf",
+                            "file_data": f"data:application/pdf;base64,{pdf_base64}",
+                        },
+                    }
+                ],
+            },
+        ],
+        stream=False,
+    )
+    message_content = response.choices[0].message.content.lower().strip()
+    assert "hello world" in message_content
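The test sends the PDF as a `data:` URL in the `file_data` field. A small, stdlib-only sketch of what that string looks like and how a receiver could unpack it; this is purely illustrative and not code from the commit:

```python
import base64

# Shape of the payload the test builds (the base64 body here is just a PDF header, for illustration).
file_data = "data:application/pdf;base64,JVBERi0xLjQK"

media_type, _, body = file_data.partition(",")
assert media_type == "data:application/pdf;base64"
assert base64.b64decode(body).startswith(b"%PDF-")  # a real PDF starts with %PDF-
```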
@@ -38,9 +38,8 @@ sys.stdout.reconfigure(line_buffering=True)

 # How to run this test:
 #
-# pytest llama_stack/providers/tests/post_training/test_post_training.py
-# -m "torchtune_post_training_huggingface_datasetio"
-# -v -s --tb=short --disable-warnings
+# LLAMA_STACK_CONFIG=ci-tests uv run --dev pytest tests/integration/post_training/test_post_training.py
+#


 class TestPostTraining:

@@ -113,6 +112,7 @@ class TestPostTraining:
                 break

             logger.info(f"Current status: {status}")
+            assert status.status in ["scheduled", "in_progress", "completed"]
             if status.status == "completed":
                 break

@@ -346,7 +346,7 @@ pip_packages:

     def test_external_provider_from_module_building(self, mock_providers):
         """Test loading an external provider from a module during build (building=True, partial spec)."""
-        from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec, Provider
+        from llama_stack.distribution.datatypes import BuildConfig, BuildProvider, DistributionSpec
         from llama_stack.providers.datatypes import Api

         # No importlib patch needed, should not import module when type of `config` is BuildConfig or DistributionSpec

@@ -358,10 +358,8 @@ pip_packages:
                 description="test",
                 providers={
                     "inference": [
-                        Provider(
-                            provider_id="external_test",
+                        BuildProvider(
                             provider_type="external_test",
-                            config={},
                             module="external_test",
                         )
                     ]
tests/unit/providers/inference/test_openai_base_url_config.py (new file, 125 lines)

@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from llama_stack.distribution.stack import replace_env_vars
+from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
+from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
+
+
+class TestOpenAIBaseURLConfig:
+    """Test that OPENAI_BASE_URL environment variable properly configures the OpenAI adapter."""
+
+    def test_default_base_url_without_env_var(self):
+        """Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
+        config = OpenAIConfig(api_key="test-key")
+        adapter = OpenAIInferenceAdapter(config)
+
+        assert adapter.get_base_url() == "https://api.openai.com/v1"
+
+    def test_custom_base_url_from_config(self):
+        """Test that the adapter uses a custom base URL when provided in config."""
+        custom_url = "https://custom.openai.com/v1"
+        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
+        adapter = OpenAIInferenceAdapter(config)
+
+        assert adapter.get_base_url() == custom_url
+
+    @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
+    def test_base_url_from_environment_variable(self):
+        """Test that the adapter uses base URL from OPENAI_BASE_URL environment variable."""
+        # Use sample_run_config which has proper environment variable syntax
+        config_data = OpenAIConfig.sample_run_config(api_key="test-key")
+        processed_config = replace_env_vars(config_data)
+        config = OpenAIConfig.model_validate(processed_config)
+        adapter = OpenAIInferenceAdapter(config)
+
+        assert adapter.get_base_url() == "https://env.openai.com/v1"
+
+    @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
+    def test_config_overrides_environment_variable(self):
+        """Test that explicit config value overrides environment variable."""
+        custom_url = "https://config.openai.com/v1"
+        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
+        adapter = OpenAIInferenceAdapter(config)
+
+        # Config should take precedence over environment variable
+        assert adapter.get_base_url() == custom_url
+
+    @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
+    def test_client_uses_configured_base_url(self, mock_openai_class):
+        """Test that the OpenAI client is initialized with the configured base URL."""
+        custom_url = "https://test.openai.com/v1"
+        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
+        adapter = OpenAIInferenceAdapter(config)
+
+        # Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
+        adapter.get_api_key = MagicMock(return_value="test-key")
+
+        # Access the client property to trigger AsyncOpenAI initialization
+        _ = adapter.client
+
+        # Verify AsyncOpenAI was called with the correct base_url
+        mock_openai_class.assert_called_once_with(
+            api_key="test-key",
+            base_url=custom_url,
+        )
+
+    @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
+    async def test_check_model_availability_uses_configured_url(self, mock_openai_class):
+        """Test that check_model_availability uses the configured base URL."""
+        custom_url = "https://test.openai.com/v1"
+        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
+        adapter = OpenAIInferenceAdapter(config)
+
+        # Mock the get_api_key method
+        adapter.get_api_key = MagicMock(return_value="test-key")
+
+        # Mock the AsyncOpenAI client and its models.retrieve method
+        mock_client = MagicMock()
+        mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
+        mock_openai_class.return_value = mock_client
+
+        # Call check_model_availability and verify it returns True
+        assert await adapter.check_model_availability("gpt-4")
+
+        # Verify the client was created with the custom URL
+        mock_openai_class.assert_called_with(
+            api_key="test-key",
+            base_url=custom_url,
+        )
+
+        # Verify the method was called and returned True
+        mock_client.models.retrieve.assert_called_once_with("gpt-4")
+
+    @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"})
+    @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
+    async def test_environment_variable_affects_model_availability_check(self, mock_openai_class):
+        """Test that setting OPENAI_BASE_URL environment variable affects where model availability is checked."""
+        # Use sample_run_config which has proper environment variable syntax
+        config_data = OpenAIConfig.sample_run_config(api_key="test-key")
+        processed_config = replace_env_vars(config_data)
+        config = OpenAIConfig.model_validate(processed_config)
+        adapter = OpenAIInferenceAdapter(config)
+
+        # Mock the get_api_key method
+        adapter.get_api_key = MagicMock(return_value="test-key")
+
+        # Mock the AsyncOpenAI client
+        mock_client = MagicMock()
+        mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
+        mock_openai_class.return_value = mock_client
+
+        # Call check_model_availability and verify it returns True
+        assert await adapter.check_model_availability("gpt-4")
+
+        # Verify the client was created with the environment variable URL
+        mock_openai_class.assert_called_with(
+            api_key="test-key",
+            base_url="https://proxy.openai.com/v1",
+        )
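These tests lean on `unittest.mock.patch.dict` to scope environment overrides to a single test. A stdlib-only sketch of that pattern, independent of the adapter code; the variable value is illustrative:

```python
import os
from unittest.mock import patch


def read_base_url() -> str:
    return os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")


# patch.dict sets the variable only for the duration of the with-block
# (or a decorated test) and restores the original environment afterwards.
with patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.example.com/v1"}):
    assert read_base_url() == "https://env.example.com/v1"
```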
@@ -4,13 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import pytest
+from pydantic import ValidationError
+
 from llama_stack.apis.common.content_types import TextContentItem
 from llama_stack.apis.inference import (
     CompletionMessage,
     OpenAIAssistantMessageParam,
+    OpenAIChatCompletionContentPartImageParam,
     OpenAIChatCompletionContentPartTextParam,
+    OpenAIDeveloperMessageParam,
+    OpenAIImageURL,
     OpenAISystemMessageParam,
+    OpenAIToolMessageParam,
     OpenAIUserMessageParam,
     SystemMessage,
     UserMessage,

@@ -108,3 +114,71 @@ async def test_openai_messages_to_messages_with_content_list():
     assert llama_messages[0].content[0].text == "system message"
     assert llama_messages[1].content[0].text == "user message"
     assert llama_messages[2].content[0].text == "assistant message"
+
+
+@pytest.mark.parametrize(
+    "message_class,kwargs",
+    [
+        (OpenAISystemMessageParam, {}),
+        (OpenAIAssistantMessageParam, {}),
+        (OpenAIDeveloperMessageParam, {}),
+        (OpenAIUserMessageParam, {}),
+        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+    ],
+)
+def test_message_accepts_text_string(message_class, kwargs):
+    """Test that messages accept string text content."""
+    msg = message_class(content="Test message", **kwargs)
+    assert msg.content == "Test message"
+
+
+@pytest.mark.parametrize(
+    "message_class,kwargs",
+    [
+        (OpenAISystemMessageParam, {}),
+        (OpenAIAssistantMessageParam, {}),
+        (OpenAIDeveloperMessageParam, {}),
+        (OpenAIUserMessageParam, {}),
+        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+    ],
+)
+def test_message_accepts_text_list(message_class, kwargs):
+    """Test that messages accept list of text content parts."""
+    content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
+    msg = message_class(content=content_list, **kwargs)
+    assert len(msg.content) == 1
+    assert msg.content[0].text == "Test message"
+
+
+@pytest.mark.parametrize(
+    "message_class,kwargs",
+    [
+        (OpenAISystemMessageParam, {}),
+        (OpenAIAssistantMessageParam, {}),
+        (OpenAIDeveloperMessageParam, {}),
+        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+    ],
+)
+def test_message_rejects_images(message_class, kwargs):
+    """Test that system, assistant, developer, and tool messages reject image content."""
+    with pytest.raises(ValidationError):
+        message_class(
+            content=[
+                OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
+            ],
+            **kwargs,
+        )
+
+
+def test_user_message_accepts_images():
+    """Test that user messages accept image content (unlike other message types)."""
+    # List with images should work
+    msg = OpenAIUserMessageParam(
+        content=[
+            OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
+            OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
+        ]
+    )
+    assert len(msg.content) == 2
+    assert msg.content[0].text == "Describe this image:"
+    assert msg.content[1].image_url.url == "http://example.com/image.jpg"
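The rejection cases above come down to how the message param models type their content field: a text-only content union fails validation when handed an image part, while the user message union admits both. A hedged sketch of that idea with stand-in pydantic models (not the actual llama_stack.apis.inference definitions):

# Stand-in models illustrating the validation behavior; the definitions here
# are assumptions for illustration, not the real llama-stack types.
from pydantic import BaseModel, ValidationError


class TextPart(BaseModel):
    type: str = "text"
    text: str


class ImageURL(BaseModel):
    url: str


class ImagePart(BaseModel):
    type: str = "image_url"
    image_url: ImageURL


class SketchSystemMessage(BaseModel):
    role: str = "system"
    content: str | list[TextPart]  # text-only content


class SketchUserMessage(BaseModel):
    role: str = "user"
    content: str | list[TextPart | ImagePart]  # text and images allowed


# User messages accept mixed text/image parts.
SketchUserMessage(
    content=[
        TextPart(text="Describe this image:"),
        ImagePart(image_url=ImageURL(url="http://example.com/image.jpg")),
    ]
)

# Text-only messages reject image parts at construction time.
try:
    SketchSystemMessage(content=[ImagePart(image_url=ImageURL(url="http://example.com/image.jpg"))])
except ValidationError:
    print("image content rejected for a text-only message, as the tests expect")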
@@ -162,26 +162,29 @@ async def test_register_model_existing_different(
     await helper.register_model(known_model)
 
 
-async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
-    await helper.register_model(known_model)  # duplicate entry
-    assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
-    await helper.unregister_model(known_model.model_id)
-    assert helper.get_provider_model_id(known_model.model_id) is None
+# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
+# async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
+#     await helper.register_model(known_model)  # duplicate entry
+#     assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
+#     await helper.unregister_model(known_model.model_id)
+#     assert helper.get_provider_model_id(known_model.model_id) is None
 
 
-async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
-    with pytest.raises(ValueError):
-        await helper.unregister_model(unknown_model.model_id)
+# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
+# async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
+#     with pytest.raises(ValueError):
+#         await helper.unregister_model(unknown_model.model_id)
 
 
 async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
     assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
 
 
-async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
-    assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
-    await helper.unregister_model(known_model.provider_resource_id)
-    assert helper.get_provider_model_id(known_model.provider_resource_id) is None
+# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
+# async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
+#     assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
+#     await helper.unregister_model(known_model.provider_resource_id)
+#     assert helper.get_provider_model_id(known_model.provider_resource_id) is None
 
 
 async def test_register_model_from_check_model_availability(
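For context on the registry calls that remain active in this hunk, a hedged sketch of the alias lookup those tests rely on; SketchRegistryHelper is an illustrative stand-in, not the real ModelRegistryHelper, which carries much more (provider data, model refresh, and so on).

# Stand-in registry helper: models known at init resolve by either their
# user-facing id or their provider id, which is what the surviving asserts check.
class SketchEntry:
    def __init__(self, model_id: str, provider_model_id: str):
        self.model_id = model_id
        self.provider_model_id = provider_model_id


class SketchRegistryHelper:
    def __init__(self, entries: list[SketchEntry]):
        self.alias_to_provider_id: dict[str, str] = {}
        for entry in entries:
            self.alias_to_provider_id[entry.model_id] = entry.provider_model_id
            self.alias_to_provider_id[entry.provider_model_id] = entry.provider_model_id

    async def register_model(self, entry: SketchEntry) -> SketchEntry:
        # Registering again with the same alias is a no-op overwrite here.
        self.alias_to_provider_id[entry.model_id] = entry.provider_model_id
        return entry

    def get_provider_model_id(self, alias: str) -> str | None:
        return self.alias_to_provider_id.get(alias)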
@@ -49,7 +49,7 @@ def github_token_app():
     )
 
     # Add auth middleware
-    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
+    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
 
     @app.get("/test")
     def test_endpoint():

@@ -149,7 +149,7 @@ def test_github_enterprise_support(mock_client_class):
         access_policy=[],
     )
 
-    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
+    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
 
     @app.get("/test")
     def test_endpoint():
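Both hunks above add an impls keyword when the test apps install AuthenticationMiddleware, which implies the middleware constructor now accepts that mapping. A generic FastAPI/Starlette sketch of the add_middleware pattern with a stand-in middleware class; SketchAuthMiddleware is not the real llama-stack AuthenticationMiddleware, and the empty impls dict simply mirrors what the tests pass.

# Stand-in middleware showing how add_middleware forwards keyword arguments
# (auth_config, impls) to the middleware constructor; illustrative only.
from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse


class SketchAuthMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, auth_config=None, impls=None):
        super().__init__(app)
        self.auth_config = auth_config
        # impls is treated here as an opaque mapping of API implementations.
        self.impls = impls or {}

    async def dispatch(self, request: Request, call_next):
        if "Authorization" not in request.headers:
            return JSONResponse({"detail": "Missing token"}, status_code=401)
        return await call_next(request)


app = FastAPI()
app.add_middleware(SketchAuthMiddleware, auth_config=None, impls={})


@app.get("/test")
def test_endpoint():
    return {"ok": True}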