Merge branch 'main' into patch-1

commit 92c2edd61c
Author: Dean Wampler, 2025-07-29 14:46:11 -04:00; committed by GitHub
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)

35 changed files with 1916 additions and 1589 deletions


@@ -14,10 +14,18 @@ concurrency:
jobs:
  pre-commit:
    runs-on: ubuntu-latest
+    permissions:
+      contents: write
+      pull-requests: write
    steps:
      - name: Checkout code
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          # For dependabot PRs, we need to checkout with a token that can push changes
+          token: ${{ github.actor == 'dependabot[bot]' && secrets.GITHUB_TOKEN || github.token }}
+          # Fetch full history for dependabot PRs to allow commits
+          fetch-depth: ${{ github.actor == 'dependabot[bot]' && 0 || 1 }}

      - name: Set up Python
        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
@@ -29,15 +37,45 @@ jobs:
          .pre-commit-config.yaml

      - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
+        continue-on-error: true
        env:
          SKIP: no-commit-to-branch
          RUFF_OUTPUT_FORMAT: github

+      - name: Debug
+        run: |
+          echo "github.ref: ${{ github.ref }}"
+          echo "github.actor: ${{ github.actor }}"
+
+      - name: Commit changes for dependabot PRs
+        if: github.actor == 'dependabot[bot]'
+        run: |
+          if ! git diff --exit-code || [ -n "$(git ls-files --others --exclude-standard)" ]; then
+            git config --local user.email "github-actions[bot]@users.noreply.github.com"
+            git config --local user.name "github-actions[bot]"
+            # Ensure we're on the correct branch
+            git checkout -B ${{ github.head_ref }}
+            git add -A
+            git commit -m "Apply pre-commit fixes"
+            # Pull latest changes from the PR branch and rebase our commit on top
+            git pull --rebase origin ${{ github.head_ref }}
+            # Push to the PR branch
+            git push origin ${{ github.head_ref }}
+            echo "Pre-commit fixes committed and pushed"
+          else
+            echo "No changes to commit"
+          fi
+
      - name: Verify if there are any diff files after pre-commit
+        if: github.actor != 'dependabot[bot]'
        run: |
          git diff --exit-code || (echo "There are uncommitted changes, run pre-commit locally and commit again" && exit 1)

      - name: Verify if there are any new files after pre-commit
+        if: github.actor != 'dependabot[bot]'
        run: |
          unstaged_files=$(git ls-files --others --exclude-standard)
          if [ -n "$unstaged_files" ]; then


@@ -13,6 +13,7 @@ on:
      - 'uv.lock'
      - 'pyproject.toml'
      - 'requirements.txt'
+      - 'tests/external/*'
      - '.github/workflows/test-external-provider-module.yml' # This workflow

jobs:


@@ -13,6 +13,7 @@ on:
      - 'uv.lock'
      - 'pyproject.toml'
      - 'requirements.txt'
+      - 'tests/external/*'
      - '.github/workflows/test-external.yml' # This workflow

jobs:


@@ -35,6 +35,8 @@ jobs:
      - name: Install dependencies
        uses: ./.github/actions/setup-runner
+        with:
+          python-version: ${{ matrix.python }}
      - name: Run unit tests
        run: |


@@ -19,7 +19,6 @@ repos:
      - id: check-yaml
        args: ["--unsafe"]
      - id: detect-private-key
-      - id: requirements-txt-fixer
      - id: mixed-line-ending
        args: [--fix=lf] # Forces to replace line ending by LF (line feed)
      - id: check-executables-have-shebangs
@@ -56,14 +55,6 @@ repos:
    rev: 0.7.20
    hooks:
      - id: uv-lock
-      - id: uv-export
-        args: [
-          "--frozen",
-          "--no-hashes",
-          "--no-emit-project",
-          "--no-default-groups",
-          "--output-file=requirements.txt"
-        ]

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.16.1


@@ -9770,7 +9770,7 @@
            {
              "type": "array",
              "items": {
-                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
              }
            }
          ],
@@ -9821,13 +9821,17 @@
        },
        {
          "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
+        },
+        {
+          "$ref": "#/components/schemas/OpenAIFile"
        }
      ],
      "discriminator": {
        "propertyName": "type",
        "mapping": {
          "text": "#/components/schemas/OpenAIChatCompletionContentPartTextParam",
-          "image_url": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
+          "image_url": "#/components/schemas/OpenAIChatCompletionContentPartImageParam",
+          "file": "#/components/schemas/OpenAIFile"
        }
      }
    },
@@ -9955,7 +9959,7 @@
            {
              "type": "array",
              "items": {
-                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
              }
            }
          ],
@@ -9974,6 +9978,41 @@
      "title": "OpenAIDeveloperMessageParam",
      "description": "A message from the developer in an OpenAI-compatible chat completion request."
    },
+    "OpenAIFile": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "file",
+          "default": "file"
+        },
+        "file": {
+          "$ref": "#/components/schemas/OpenAIFileFile"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type",
+        "file"
+      ],
+      "title": "OpenAIFile"
+    },
+    "OpenAIFileFile": {
+      "type": "object",
+      "properties": {
+        "file_data": {
+          "type": "string"
+        },
+        "file_id": {
+          "type": "string"
+        },
+        "filename": {
+          "type": "string"
+        }
+      },
+      "additionalProperties": false,
+      "title": "OpenAIFileFile"
+    },
    "OpenAIImageURL": {
      "type": "object",
      "properties": {
@@ -10036,7 +10075,7 @@
            {
              "type": "array",
              "items": {
-                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
              }
            }
          ],
@@ -10107,7 +10146,7 @@
            {
              "type": "array",
              "items": {
-                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
              }
            }
          ],


@@ -6895,7 +6895,7 @@ components:
          - type: string
          - type: array
            items:
-              $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+              $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
        description: The content of the model's response
      name:
        type: string
@@ -6934,11 +6934,13 @@ components:
      oneOf:
        - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
        - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
+        - $ref: '#/components/schemas/OpenAIFile'
      discriminator:
        propertyName: type
        mapping:
          text: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
          image_url: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
+          file: '#/components/schemas/OpenAIFile'
    OpenAIChatCompletionContentPartTextParam:
      type: object
      properties:
@@ -7037,7 +7039,7 @@ components:
          - type: string
          - type: array
            items:
-              $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+              $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
        description: The content of the developer message
      name:
        type: string
@@ -7050,6 +7052,31 @@ components:
      title: OpenAIDeveloperMessageParam
      description: >-
        A message from the developer in an OpenAI-compatible chat completion request.
+    OpenAIFile:
+      type: object
+      properties:
+        type:
+          type: string
+          const: file
+          default: file
+        file:
+          $ref: '#/components/schemas/OpenAIFileFile'
+      additionalProperties: false
+      required:
+        - type
+        - file
+      title: OpenAIFile
+    OpenAIFileFile:
+      type: object
+      properties:
+        file_data:
+          type: string
+        file_id:
+          type: string
+        filename:
+          type: string
+      additionalProperties: false
+      title: OpenAIFileFile
    OpenAIImageURL:
      type: object
      properties:
@@ -7090,7 +7117,7 @@ components:
          - type: string
          - type: array
            items:
-              $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+              $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
        description: >-
          The content of the "system prompt". If multiple system messages are provided,
          they are concatenated. The underlying Llama Stack code may also add other
@@ -7148,7 +7175,7 @@ components:
          - type: string
          - type: array
            items:
-              $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+              $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
        description: The response content from the tool
      additionalProperties: false
      required:


@@ -13,7 +13,7 @@ llama stack build --template starter --image-type venv
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient(
-    "ollama",
+    "starter",
    # provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
    provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
)


@@ -12,8 +12,7 @@ To enable external providers, you need to add `module` into your build yaml, all
an example entry in your build.yaml should look like:

```
-- provider_id: ramalama
-  provider_type: remote::ramalama
+- provider_type: remote::ramalama
  module: ramalama_stack
```

@@ -255,8 +254,7 @@ distribution_spec:
  container_image: null
  providers:
    inference:
-    - provider_id: ramalama
-      provider_type: remote::ramalama
+    - provider_type: remote::ramalama
      module: ramalama_stack==0.3.0a0
  image_type: venv
  image_name: null


@@ -9,11 +9,13 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `api_key` | `str \| None` | No | | API key for OpenAI models |
+| `base_url` | `<class 'str'>` | No | https://api.openai.com/v1 | Base URL for OpenAI API |

## Sample Configuration

```yaml
api_key: ${env.OPENAI_API_KEY:=}
+base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
```


@@ -455,8 +455,21 @@ class OpenAIChatCompletionContentPartImageParam(BaseModel):
    image_url: OpenAIImageURL


+@json_schema_type
+class OpenAIFileFile(BaseModel):
+    file_data: str | None = None
+    file_id: str | None = None
+    filename: str | None = None
+
+
+@json_schema_type
+class OpenAIFile(BaseModel):
+    type: Literal["file"] = "file"
+    file: OpenAIFileFile
+
+
OpenAIChatCompletionContentPartParam = Annotated[
-    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
    Field(discriminator="type"),
]
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
@@ -464,6 +477,8 @@ register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletion

OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]

+OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
+

@json_schema_type
class OpenAIUserMessageParam(BaseModel):
@@ -489,7 +504,7 @@ class OpenAISystemMessageParam(BaseModel):
    """

    role: Literal["system"] = "system"
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent
    name: str | None = None

@@ -518,7 +533,7 @@ class OpenAIAssistantMessageParam(BaseModel):
    """

    role: Literal["assistant"] = "assistant"
-    content: OpenAIChatCompletionMessageContent | None = None
+    content: OpenAIChatCompletionTextOnlyMessageContent | None = None
    name: str | None = None
    tool_calls: list[OpenAIChatCompletionToolCall] | None = None

@@ -534,7 +549,7 @@ class OpenAIToolMessageParam(BaseModel):

    role: Literal["tool"] = "tool"
    tool_call_id: str
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent


@json_schema_type
@@ -547,7 +562,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
    """

    role: Literal["developer"] = "developer"
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent
    name: str | None = None
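Taken together, these type changes mean only user messages can carry non-text parts. A minimal sketch of how the new file part composes with a user message (illustrative only; it assumes OpenAIFile and OpenAIFileFile are re-exported from llama_stack.apis.inference alongside the other message types, and the data URL is a placeholder):

```python
from llama_stack.apis.inference import (
    OpenAIChatCompletionContentPartTextParam,
    OpenAIFile,
    OpenAIFileFile,
    OpenAIUserMessageParam,
)

# A user message mixing a text part with the new base64-encoded file part.
message = OpenAIUserMessageParam(
    content=[
        OpenAIChatCompletionContentPartTextParam(text="Summarize the attached PDF."),
        OpenAIFile(file=OpenAIFileFile(filename="report.pdf", file_data="data:application/pdf;base64,...")),
    ]
)

# System, developer, tool, and assistant messages now use the text-only content
# alias, so a file or image part there fails Pydantic validation.
```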


@@ -18,10 +18,6 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}

-# Mount command for cache container .cache, can be overridden by the user if needed
-MOUNT_CACHE=${MOUNT_CACHE:-"--mount=type=cache,id=llama-stack-cache,target=/root/.cache"}
-
# Path to the run.yaml file in the container
RUN_CONFIG_PATH=/app/run.yaml

@@ -176,18 +172,13 @@ RUN pip install uv
EOF
fi

-# Set the link mode to copy so that uv doesn't attempt to symlink to the cache directory
-add_to_container << EOF
-ENV UV_LINK_MODE=copy
-EOF
-
# Add pip dependencies first since llama-stack is what will change most often
# so we can reuse layers.
if [ -n "$normal_deps" ]; then
    read -ra pip_args <<< "$normal_deps"
    quoted_deps=$(printf " %q" "${pip_args[@]}")
    add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install $quoted_deps
+RUN uv pip install --no-cache $quoted_deps
EOF
fi

@@ -197,7 +188,7 @@ if [ -n "$optional_deps" ]; then
        read -ra pip_args <<< "$part"
        quoted_deps=$(printf " %q" "${pip_args[@]}")
        add_to_container <<EOF
-RUN $MOUNT_CACHE uv pip install $quoted_deps
+RUN uv pip install --no-cache $quoted_deps
EOF
    done
fi

@@ -208,10 +199,10 @@ if [ -n "$external_provider_deps" ]; then
        read -ra pip_args <<< "$part"
        quoted_deps=$(printf " %q" "${pip_args[@]}")
        add_to_container <<EOF
-RUN $MOUNT_CACHE uv pip install $quoted_deps
+RUN uv pip install --no-cache $quoted_deps
EOF
        add_to_container <<EOF
-RUN python3 - <<PYTHON | $MOUNT_CACHE uv pip install -r -
+RUN python3 - <<PYTHON | uv pip install --no-cache -r -
import importlib
import sys

@@ -293,7 +284,7 @@ COPY $dir $mount_point
EOF
    fi
    add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install -e $mount_point
+RUN uv pip install --no-cache -e $mount_point
EOF
}

@@ -308,10 +299,10 @@ else
    if [ -n "$TEST_PYPI_VERSION" ]; then
        # these packages are damaged in test-pypi, so install them first
        add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install fastapi libcst
+RUN uv pip install --no-cache fastapi libcst
EOF
        add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install --extra-index-url https://test.pypi.org/simple/ \
+RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
    --index-strategy unsafe-best-match \
    llama-stack==$TEST_PYPI_VERSION

@@ -323,7 +314,7 @@ EOF
        SPEC_VERSION="llama-stack"
    fi
    add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install $SPEC_VERSION
+RUN uv pip install --no-cache $SPEC_VERSION
EOF
    fi
fi


@@ -358,7 +358,7 @@ async def shutdown_stack(impls: dict[Api, Any]):

async def refresh_registry_once(impls: dict[Api, Any]):
-    logger.info("refreshing registry")
+    logger.debug("refreshing registry")
    routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
    for routing_table in routing_tables:
        await routing_table.refresh()


@@ -469,7 +469,7 @@ class HFFinetuningSingleDevice:
            use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
            save_strategy=save_strategy,
            report_to="none",
-            max_seq_length=provider_config.max_seq_length,
+            max_length=provider_config.max_seq_length,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            gradient_checkpointing=provider_config.gradient_checkpointing,
            learning_rate=lr,


@@ -32,7 +32,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
-            litellm_provider_name="llama",
+            litellm_provider_name="meta_llama",
            api_key_from_config=config.api_key,
            provider_data_api_key_field="llama_api_key",
            openai_compat_api_base=config.openai_compat_api_base,


@@ -24,9 +24,19 @@ class OpenAIConfig(BaseModel):
        default=None,
        description="API key for OpenAI models",
    )
+    base_url: str = Field(
+        default="https://api.openai.com/v1",
+        description="Base URL for OpenAI API",
+    )

    @classmethod
-    def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY:=}", **kwargs) -> dict[str, Any]:
+    def sample_run_config(
+        cls,
+        api_key: str = "${env.OPENAI_API_KEY:=}",
+        base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
+        **kwargs,
+    ) -> dict[str, Any]:
        return {
            "api_key": api_key,
+            "base_url": base_url,
        }
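As a quick illustration of how the new field resolves at run time (a sketch based on the unit tests added later in this commit; replace_env_vars is the stack's existing env-substitution helper, and the proxy URL is a made-up example):

```python
import os

from llama_stack.distribution.stack import replace_env_vars
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig

os.environ["OPENAI_BASE_URL"] = "https://proxy.example.com/v1"  # hypothetical proxy endpoint

# sample_run_config() emits ${env.OPENAI_BASE_URL:=https://api.openai.com/v1};
# replace_env_vars() substitutes the environment value, or the default when unset.
config = OpenAIConfig.model_validate(replace_env_vars(OpenAIConfig.sample_run_config(api_key="test-key")))
assert config.base_url == "https://proxy.example.com/v1"
```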


@@ -65,9 +65,9 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
        """
        Get the OpenAI API base URL.

-        Returns the standard OpenAI API base URL for direct OpenAI API calls.
+        Returns the OpenAI API base URL from the configuration.
        """
-        return "https://api.openai.com/v1"
+        return self.config.base_url

    async def initialize(self) -> None:
        await super().initialize()


@@ -73,6 +73,15 @@ class LiteLLMOpenAIMixin(
        provider_data_api_key_field: str,
        openai_compat_api_base: str | None = None,
    ):
+        """
+        Initialize the LiteLLMOpenAIMixin.
+
+        :param model_entries: The model entries to register.
+        :param api_key_from_config: The API key to use from the config.
+        :param provider_data_api_key_field: The field in the provider data that contains the API key.
+        :param litellm_provider_name: The name of the provider, used for model lookups.
+        :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
+        """
        ModelRegistryHelper.__init__(self, model_entries)

        self.litellm_provider_name = litellm_provider_name
@@ -428,3 +437,17 @@ class LiteLLMOpenAIMixin(
        logprobs: LogProbConfig | None = None,
    ):
        raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
+
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available via LiteLLM for the current
+        provider (self.litellm_provider_name).
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        if self.litellm_provider_name not in litellm.models_by_provider:
+            logger.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
+            return False
+
+        return model in litellm.models_by_provider[self.litellm_provider_name]
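A short usage sketch for the new helper (hypothetical wrapper and adapter instance; any class deriving from LiteLLMOpenAIMixin, such as the Llama API adapter above, inherits the method):

```python
import asyncio


async def ensure_model(adapter, model_id: str) -> None:
    # check_model_availability() consults litellm.models_by_provider for
    # adapter.litellm_provider_name, so unknown providers or models return False.
    if not await adapter.check_model_availability(model_id):
        raise ValueError(f"{model_id} is not listed for provider {adapter.litellm_provider_name}")


# Example (assumes `adapter` was constructed elsewhere):
# asyncio.run(ensure_model(adapter, "gpt-4o"))
```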


@@ -56,6 +56,7 @@ providers:
    provider_type: remote::openai
    config:
      api_key: ${env.OPENAI_API_KEY:=}
+      base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
  - provider_id: anthropic
    provider_type: remote::anthropic
    config:


@@ -16,6 +16,7 @@ providers:
    provider_type: remote::openai
    config:
      api_key: ${env.OPENAI_API_KEY:=}
+      base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
  - provider_id: anthropic
    provider_type: remote::anthropic
    config:


@@ -56,6 +56,7 @@ providers:
    provider_type: remote::openai
    config:
      api_key: ${env.OPENAI_API_KEY:=}
+      base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
  - provider_id: anthropic
    provider_type: remote::anthropic
    config:


@@ -20,7 +20,7 @@
    "@radix-ui/react-tooltip": "^1.2.6",
    "class-variance-authority": "^0.7.1",
    "clsx": "^2.1.1",
-    "llama-stack-client": "^0.2.15",
+    "llama-stack-client": "^0.2.16",
    "lucide-react": "^0.510.0",
    "next": "15.3.3",
    "next-auth": "^4.24.11",


@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "llama_stack"
-version = "0.2.15"
+version = "0.2.16"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack"
readme = "README.md"
@@ -28,7 +28,7 @@ dependencies = [
    "huggingface-hub>=0.34.0,<1.0",
    "jinja2>=3.1.6",
    "jsonschema",
-    "llama-stack-client>=0.2.15",
+    "llama-stack-client>=0.2.16",
    "llama-api-client>=0.1.2",
    "openai>=1.66",
    "prompt-toolkit",
@@ -53,7 +53,7 @@ dependencies = [
ui = [
    "streamlit",
    "pandas",
-    "llama-stack-client>=0.2.15",
+    "llama-stack-client>=0.2.16",
    "streamlit-option-menu",
]

@@ -114,6 +114,7 @@ test = [
    "sqlalchemy[asyncio]>=2.0.41",
    "requests",
    "pymilvus>=2.5.12",
+    "reportlab",
]
docs = [
    "setuptools",


@@ -1,269 +0,0 @@
# This file was autogenerated by uv via the following command:
# uv export --frozen --no-hashes --no-emit-project --no-default-groups --output-file=requirements.txt
aiohappyeyeballs==2.5.0
# via aiohttp
aiohttp==3.12.13
# via llama-stack
aiosignal==1.3.2
# via aiohttp
aiosqlite==0.21.0
# via llama-stack
annotated-types==0.7.0
# via pydantic
anyio==4.8.0
# via
# httpx
# llama-api-client
# llama-stack-client
# openai
# starlette
asyncpg==0.30.0
# via llama-stack
attrs==25.1.0
# via
# aiohttp
# jsonschema
# referencing
certifi==2025.1.31
# via
# httpcore
# httpx
# requests
cffi==1.17.1 ; platform_python_implementation != 'PyPy'
# via cryptography
charset-normalizer==3.4.1
# via requests
click==8.1.8
# via
# llama-stack-client
# uvicorn
colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==45.0.5
# via python-jose
deprecated==1.2.18
# via
# opentelemetry-api
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-semantic-conventions
distro==1.9.0
# via
# llama-api-client
# llama-stack-client
# openai
ecdsa==0.19.1
# via python-jose
fastapi==0.115.8
# via llama-stack
filelock==3.17.0
# via huggingface-hub
fire==0.7.0
# via
# llama-stack
# llama-stack-client
frozenlist==1.5.0
# via
# aiohttp
# aiosignal
fsspec==2024.12.0
# via huggingface-hub
googleapis-common-protos==1.67.0
# via opentelemetry-exporter-otlp-proto-http
h11==0.16.0
# via
# httpcore
# llama-stack
# uvicorn
hf-xet==1.1.5 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
# via huggingface-hub
httpcore==1.0.9
# via httpx
httpx==0.28.1
# via
# llama-api-client
# llama-stack
# llama-stack-client
# openai
huggingface-hub==0.34.1
# via llama-stack
idna==3.10
# via
# anyio
# httpx
# requests
# yarl
importlib-metadata==8.5.0
# via opentelemetry-api
jinja2==3.1.6
# via llama-stack
jiter==0.8.2
# via openai
jsonschema==4.23.0
# via llama-stack
jsonschema-specifications==2024.10.1
# via jsonschema
llama-api-client==0.1.2
# via llama-stack
llama-stack-client==0.2.15
# via llama-stack
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2
# via jinja2
mdurl==0.1.2
# via markdown-it-py
multidict==6.1.0
# via
# aiohttp
# yarl
numpy==2.2.3
# via pandas
openai==1.71.0
# via llama-stack
opentelemetry-api==1.30.0
# via
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-sdk
# opentelemetry-semantic-conventions
opentelemetry-exporter-otlp-proto-common==1.30.0
# via opentelemetry-exporter-otlp-proto-http
opentelemetry-exporter-otlp-proto-http==1.30.0
# via llama-stack
opentelemetry-proto==1.30.0
# via
# opentelemetry-exporter-otlp-proto-common
# opentelemetry-exporter-otlp-proto-http
opentelemetry-sdk==1.30.0
# via
# llama-stack
# opentelemetry-exporter-otlp-proto-http
opentelemetry-semantic-conventions==0.51b0
# via opentelemetry-sdk
packaging==24.2
# via huggingface-hub
pandas==2.2.3
# via llama-stack-client
pillow==11.1.0
# via llama-stack
prompt-toolkit==3.0.50
# via
# llama-stack
# llama-stack-client
propcache==0.3.0
# via
# aiohttp
# yarl
protobuf==5.29.5
# via
# googleapis-common-protos
# opentelemetry-proto
pyaml==25.1.0
# via llama-stack-client
pyasn1==0.4.8
# via
# python-jose
# rsa
pycparser==2.22 ; platform_python_implementation != 'PyPy'
# via cffi
pydantic==2.10.6
# via
# fastapi
# llama-api-client
# llama-stack
# llama-stack-client
# openai
pydantic-core==2.27.2
# via pydantic
pygments==2.19.1
# via rich
python-dateutil==2.9.0.post0
# via pandas
python-dotenv==1.0.1
# via llama-stack
python-jose==3.4.0
# via llama-stack
python-multipart==0.0.20
# via llama-stack
pytz==2025.1
# via pandas
pyyaml==6.0.2
# via
# huggingface-hub
# pyaml
referencing==0.36.2
# via
# jsonschema
# jsonschema-specifications
regex==2024.11.6
# via tiktoken
requests==2.32.4
# via
# huggingface-hub
# llama-stack-client
# opentelemetry-exporter-otlp-proto-http
# tiktoken
rich==13.9.4
# via
# llama-stack
# llama-stack-client
rpds-py==0.22.3
# via
# jsonschema
# referencing
rsa==4.9
# via python-jose
six==1.17.0
# via
# ecdsa
# python-dateutil
sniffio==1.3.1
# via
# anyio
# llama-api-client
# llama-stack-client
# openai
starlette==0.45.3
# via
# fastapi
# llama-stack
termcolor==2.5.0
# via
# fire
# llama-stack
# llama-stack-client
tiktoken==0.9.0
# via llama-stack
tqdm==4.67.1
# via
# huggingface-hub
# llama-stack-client
# openai
typing-extensions==4.12.2
# via
# aiosqlite
# anyio
# fastapi
# huggingface-hub
# llama-api-client
# llama-stack-client
# openai
# opentelemetry-sdk
# pydantic
# pydantic-core
# referencing
tzdata==2025.1
# via pandas
urllib3==2.5.0
# via requests
uvicorn==0.34.0
# via llama-stack
wcwidth==0.2.13
# via prompt-toolkit
wrapt==1.17.2
# via deprecated
yarl==1.18.3
# via aiohttp
zipp==3.21.0
# via importlib-metadata


@@ -8,6 +8,15 @@

PYTHON_VERSION=${PYTHON_VERSION:-3.12}

+set -e
+
+# Always run this at the end, even if something fails
+cleanup() {
+    echo "Generating coverage report..."
+    uv run --python "$PYTHON_VERSION" coverage html -d htmlcov-$PYTHON_VERSION
+}
+trap cleanup EXIT
+
command -v uv >/dev/null 2>&1 || { echo >&2 "uv is required but it's not installed. Exiting."; exit 1; }

uv python find "$PYTHON_VERSION"
@@ -19,6 +28,3 @@ fi
# Run unit tests with coverage
uv run --python "$PYTHON_VERSION" --with-editable . --group unit \
    coverage run --source=llama_stack -m pytest -s -v tests/unit/ "$@"
-
-# Generate HTML coverage report
-uv run --python "$PYTHON_VERSION" coverage html -d htmlcov-$PYTHON_VERSION


@@ -3,8 +3,7 @@ distribution_spec:
  description: Custom distro for CI tests
  providers:
    weather:
-    - provider_id: kaze
-      provider_type: remote::kaze
+    - provider_type: remote::kaze
  image_type: venv
  image_name: ci-test
  external_providers_dir: ~/.llama/providers.d


@@ -4,8 +4,7 @@ distribution_spec:
  container_image: null
  providers:
    inference:
-    - provider_id: ramalama
-      provider_type: remote::ramalama
+    - provider_type: remote::ramalama
      module: ramalama_stack==0.3.0a0
  image_type: venv
  image_name: ramalama-stack-test


@@ -5,8 +5,14 @@
# the root directory of this source tree.


+import base64
+import os
+import tempfile
+
import pytest
from openai import OpenAI
+from reportlab.lib.pagesizes import letter
+from reportlab.pdfgen import canvas

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

@@ -82,6 +88,14 @@ def skip_if_provider_isnt_vllm(client_with_models, model_id):
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support vllm extra_body parameters.")


+def skip_if_provider_isnt_openai(client_with_models, model_id):
+    provider = provider_from_model(client_with_models, model_id)
+    if provider.provider_type != "remote::openai":
+        pytest.skip(
+            f"Model {model_id} hosted by {provider.provider_type} doesn't support chat completion calls with base64 encoded files."
+        )
+
+
@pytest.fixture
def openai_client(client_with_models):
    base_url = f"{client_with_models.base_url}/v1/openai/v1"
@@ -418,3 +432,45 @@ def test_inference_store_tool_calls(compat_client, client_with_models, text_mode
    # failed tool call parses show up as a message with content, so ensure
    # that the retrieve response content matches the original request
    assert retrieved_response.choices[0].message.content == content
+
+
+def test_openai_chat_completion_non_streaming_with_file(openai_client, client_with_models, text_model_id):
+    skip_if_provider_isnt_openai(client_with_models, text_model_id)
+
+    # Generate temporary PDF with "Hello World" text
+    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_pdf:
+        c = canvas.Canvas(temp_pdf.name, pagesize=letter)
+        c.drawString(100, 750, "Hello World")
+        c.save()
+
+        # Read the PDF and encode to base64
+        with open(temp_pdf.name, "rb") as pdf_file:
+            pdf_base64 = base64.b64encode(pdf_file.read()).decode("utf-8")
+
+        # Clean up temporary file
+        os.unlink(temp_pdf.name)
+
+    response = openai_client.chat.completions.create(
+        model=text_model_id,
+        messages=[
+            {
+                "role": "user",
+                "content": "Describe what you see in this PDF file.",
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "file",
+                        "file": {
+                            "filename": "my-temp-hello-world-pdf",
+                            "file_data": f"data:application/pdf;base64,{pdf_base64}",
+                        },
+                    }
+                ],
+            },
+        ],
+        stream=False,
+    )
+
+    message_content = response.choices[0].message.content.lower().strip()
+    assert "hello world" in message_content


@@ -38,9 +38,8 @@ sys.stdout.reconfigure(line_buffering=True)

# How to run this test:
#
-# pytest llama_stack/providers/tests/post_training/test_post_training.py
-#   -m "torchtune_post_training_huggingface_datasetio"
-#   -v -s --tb=short --disable-warnings
+# LLAMA_STACK_CONFIG=ci-tests uv run --dev pytest tests/integration/post_training/test_post_training.py
+#


class TestPostTraining:
@@ -113,6 +112,7 @@ class TestPostTraining:
                break

            logger.info(f"Current status: {status}")
+            assert status.status in ["scheduled", "in_progress", "completed"]
            if status.status == "completed":
                break


@@ -346,7 +346,7 @@ pip_packages:
    def test_external_provider_from_module_building(self, mock_providers):
        """Test loading an external provider from a module during build (building=True, partial spec)."""
-        from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec, Provider
+        from llama_stack.distribution.datatypes import BuildConfig, BuildProvider, DistributionSpec
        from llama_stack.providers.datatypes import Api

        # No importlib patch needed, should not import module when type of `config` is BuildConfig or DistributionSpec
@@ -358,10 +358,8 @@ pip_packages:
            description="test",
            providers={
                "inference": [
-                    Provider(
-                        provider_id="external_test",
+                    BuildProvider(
                        provider_type="external_test",
-                        config={},
                        module="external_test",
                    )
                ]


@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from unittest.mock import AsyncMock, MagicMock, patch

from llama_stack.distribution.stack import replace_env_vars
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter


class TestOpenAIBaseURLConfig:
    """Test that OPENAI_BASE_URL environment variable properly configures the OpenAI adapter."""

    def test_default_base_url_without_env_var(self):
        """Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
        config = OpenAIConfig(api_key="test-key")
        adapter = OpenAIInferenceAdapter(config)

        assert adapter.get_base_url() == "https://api.openai.com/v1"

    def test_custom_base_url_from_config(self):
        """Test that the adapter uses a custom base URL when provided in config."""
        custom_url = "https://custom.openai.com/v1"
        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
        adapter = OpenAIInferenceAdapter(config)

        assert adapter.get_base_url() == custom_url

    @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
    def test_base_url_from_environment_variable(self):
        """Test that the adapter uses base URL from OPENAI_BASE_URL environment variable."""
        # Use sample_run_config which has proper environment variable syntax
        config_data = OpenAIConfig.sample_run_config(api_key="test-key")
        processed_config = replace_env_vars(config_data)
        config = OpenAIConfig.model_validate(processed_config)
        adapter = OpenAIInferenceAdapter(config)

        assert adapter.get_base_url() == "https://env.openai.com/v1"

    @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
    def test_config_overrides_environment_variable(self):
        """Test that explicit config value overrides environment variable."""
        custom_url = "https://config.openai.com/v1"
        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
        adapter = OpenAIInferenceAdapter(config)

        # Config should take precedence over environment variable
        assert adapter.get_base_url() == custom_url

    @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
    def test_client_uses_configured_base_url(self, mock_openai_class):
        """Test that the OpenAI client is initialized with the configured base URL."""
        custom_url = "https://test.openai.com/v1"
        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
        adapter = OpenAIInferenceAdapter(config)
        # Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
        adapter.get_api_key = MagicMock(return_value="test-key")

        # Access the client property to trigger AsyncOpenAI initialization
        _ = adapter.client

        # Verify AsyncOpenAI was called with the correct base_url
        mock_openai_class.assert_called_once_with(
            api_key="test-key",
            base_url=custom_url,
        )

    @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
    async def test_check_model_availability_uses_configured_url(self, mock_openai_class):
        """Test that check_model_availability uses the configured base URL."""
        custom_url = "https://test.openai.com/v1"
        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
        adapter = OpenAIInferenceAdapter(config)
        # Mock the get_api_key method
        adapter.get_api_key = MagicMock(return_value="test-key")

        # Mock the AsyncOpenAI client and its models.retrieve method
        mock_client = MagicMock()
        mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
        mock_openai_class.return_value = mock_client

        # Call check_model_availability and verify it returns True
        assert await adapter.check_model_availability("gpt-4")

        # Verify the client was created with the custom URL
        mock_openai_class.assert_called_with(
            api_key="test-key",
            base_url=custom_url,
        )

        # Verify the method was called and returned True
        mock_client.models.retrieve.assert_called_once_with("gpt-4")

    @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"})
    @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
    async def test_environment_variable_affects_model_availability_check(self, mock_openai_class):
        """Test that setting OPENAI_BASE_URL environment variable affects where model availability is checked."""
        # Use sample_run_config which has proper environment variable syntax
        config_data = OpenAIConfig.sample_run_config(api_key="test-key")
        processed_config = replace_env_vars(config_data)
        config = OpenAIConfig.model_validate(processed_config)
        adapter = OpenAIInferenceAdapter(config)
        # Mock the get_api_key method
        adapter.get_api_key = MagicMock(return_value="test-key")

        # Mock the AsyncOpenAI client
        mock_client = MagicMock()
        mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
        mock_openai_class.return_value = mock_client

        # Call check_model_availability and verify it returns True
        assert await adapter.check_model_availability("gpt-4")

        # Verify the client was created with the environment variable URL
        mock_openai_class.assert_called_with(
            api_key="test-key",
            base_url="https://proxy.openai.com/v1",
        )


@@ -4,13 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

+import pytest
+from pydantic import ValidationError
+
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import (
    CompletionMessage,
    OpenAIAssistantMessageParam,
+    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
+    OpenAIDeveloperMessageParam,
+    OpenAIImageURL,
    OpenAISystemMessageParam,
+    OpenAIToolMessageParam,
    OpenAIUserMessageParam,
    SystemMessage,
    UserMessage,
@@ -108,3 +114,71 @@ async def test_openai_messages_to_messages_with_content_list():
    assert llama_messages[0].content[0].text == "system message"
    assert llama_messages[1].content[0].text == "user message"
    assert llama_messages[2].content[0].text == "assistant message"
+
+
+@pytest.mark.parametrize(
+    "message_class,kwargs",
+    [
+        (OpenAISystemMessageParam, {}),
+        (OpenAIAssistantMessageParam, {}),
+        (OpenAIDeveloperMessageParam, {}),
+        (OpenAIUserMessageParam, {}),
+        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+    ],
+)
+def test_message_accepts_text_string(message_class, kwargs):
+    """Test that messages accept string text content."""
+    msg = message_class(content="Test message", **kwargs)
+    assert msg.content == "Test message"
+
+
+@pytest.mark.parametrize(
+    "message_class,kwargs",
+    [
+        (OpenAISystemMessageParam, {}),
+        (OpenAIAssistantMessageParam, {}),
+        (OpenAIDeveloperMessageParam, {}),
+        (OpenAIUserMessageParam, {}),
+        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+    ],
+)
+def test_message_accepts_text_list(message_class, kwargs):
+    """Test that messages accept list of text content parts."""
+    content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
+    msg = message_class(content=content_list, **kwargs)
+    assert len(msg.content) == 1
+    assert msg.content[0].text == "Test message"
+
+
+@pytest.mark.parametrize(
+    "message_class,kwargs",
+    [
+        (OpenAISystemMessageParam, {}),
+        (OpenAIAssistantMessageParam, {}),
+        (OpenAIDeveloperMessageParam, {}),
+        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+    ],
+)
+def test_message_rejects_images(message_class, kwargs):
+    """Test that system, assistant, developer, and tool messages reject image content."""
+    with pytest.raises(ValidationError):
+        message_class(
+            content=[
+                OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
+            ],
+            **kwargs,
+        )
+
+
+def test_user_message_accepts_images():
+    """Test that user messages accept image content (unlike other message types)."""
+    # List with images should work
+    msg = OpenAIUserMessageParam(
+        content=[
+            OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
+            OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
+        ]
+    )
+    assert len(msg.content) == 2
+    assert msg.content[0].text == "Describe this image:"
+    assert msg.content[1].image_url.url == "http://example.com/image.jpg"


@@ -162,26 +162,29 @@ async def test_register_model_existing_different(
        await helper.register_model(known_model)


-async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
-    await helper.register_model(known_model)  # duplicate entry
-    assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
-    await helper.unregister_model(known_model.model_id)
-    assert helper.get_provider_model_id(known_model.model_id) is None
+# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
+# async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
+#     await helper.register_model(known_model)  # duplicate entry
+#     assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
+#     await helper.unregister_model(known_model.model_id)
+#     assert helper.get_provider_model_id(known_model.model_id) is None


-async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
-    with pytest.raises(ValueError):
-        await helper.unregister_model(unknown_model.model_id)
+# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
+# async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
+#     with pytest.raises(ValueError):
+#         await helper.unregister_model(unknown_model.model_id)


async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
    assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id


-async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
-    assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
-    await helper.unregister_model(known_model.provider_resource_id)
-    assert helper.get_provider_model_id(known_model.provider_resource_id) is None
+# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
+# async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
+#     assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
+#     await helper.unregister_model(known_model.provider_resource_id)
+#     assert helper.get_provider_model_id(known_model.provider_resource_id) is None


async def test_register_model_from_check_model_availability(


@@ -49,7 +49,7 @@ def github_token_app():
    )

    # Add auth middleware
-    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
+    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():
@@ -149,7 +149,7 @@ def test_github_enterprise_support(mock_client_class):
        access_policy=[],
    )

-    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
+    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():

uv.lock (generated): 2668 changed lines; file diff suppressed because it is too large.