Merge branch 'main' into patch-1
Commit 92c2edd61c
35 changed files with 1916 additions and 1589 deletions
.github/workflows/pre-commit.yml (38 changes)

@@ -14,10 +14,18 @@ concurrency:
jobs:
  pre-commit:
    runs-on: ubuntu-latest
    permissions:
      contents: write
      pull-requests: write

    steps:
      - name: Checkout code
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          # For dependabot PRs, we need to checkout with a token that can push changes
          token: ${{ github.actor == 'dependabot[bot]' && secrets.GITHUB_TOKEN || github.token }}
          # Fetch full history for dependabot PRs to allow commits
          fetch-depth: ${{ github.actor == 'dependabot[bot]' && 0 || 1 }}

      - name: Set up Python
        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0

@@ -29,15 +37,45 @@ jobs:
            .pre-commit-config.yaml

      - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
        continue-on-error: true
        env:
          SKIP: no-commit-to-branch
          RUFF_OUTPUT_FORMAT: github

      - name: Debug
        run: |
          echo "github.ref: ${{ github.ref }}"
          echo "github.actor: ${{ github.actor }}"

      - name: Commit changes for dependabot PRs
        if: github.actor == 'dependabot[bot]'
        run: |
          if ! git diff --exit-code || [ -n "$(git ls-files --others --exclude-standard)" ]; then
            git config --local user.email "github-actions[bot]@users.noreply.github.com"
            git config --local user.name "github-actions[bot]"

            # Ensure we're on the correct branch
            git checkout -B ${{ github.head_ref }}
            git add -A
            git commit -m "Apply pre-commit fixes"

            # Pull latest changes from the PR branch and rebase our commit on top
            git pull --rebase origin ${{ github.head_ref }}

            # Push to the PR branch
            git push origin ${{ github.head_ref }}
            echo "Pre-commit fixes committed and pushed"
          else
            echo "No changes to commit"
          fi

      - name: Verify if there are any diff files after pre-commit
        if: github.actor != 'dependabot[bot]'
        run: |
          git diff --exit-code || (echo "There are uncommitted changes, run pre-commit locally and commit again" && exit 1)

      - name: Verify if there are any new files after pre-commit
        if: github.actor != 'dependabot[bot]'
        run: |
          unstaged_files=$(git ls-files --others --exclude-standard)
          if [ -n "$unstaged_files" ]; then
@@ -13,6 +13,7 @@ on:
      - 'uv.lock'
      - 'pyproject.toml'
      - 'requirements.txt'
      - 'tests/external/*'
      - '.github/workflows/test-external-provider-module.yml' # This workflow

jobs:
.github/workflows/test-external.yml (1 change)

@@ -13,6 +13,7 @@ on:
      - 'uv.lock'
      - 'pyproject.toml'
      - 'requirements.txt'
      - 'tests/external/*'
      - '.github/workflows/test-external.yml' # This workflow

jobs:
.github/workflows/unit-tests.yml (2 changes)

@@ -35,6 +35,8 @@ jobs:

      - name: Install dependencies
        uses: ./.github/actions/setup-runner
        with:
          python-version: ${{ matrix.python }}

      - name: Run unit tests
        run: |
@@ -19,7 +19,6 @@ repos:
      - id: check-yaml
        args: ["--unsafe"]
      - id: detect-private-key
      - id: requirements-txt-fixer
      - id: mixed-line-ending
        args: [--fix=lf] # Forces to replace line ending by LF (line feed)
      - id: check-executables-have-shebangs

@@ -56,14 +55,6 @@ repos:
    rev: 0.7.20
    hooks:
      - id: uv-lock
      - id: uv-export
        args: [
          "--frozen",
          "--no-hashes",
          "--no-emit-project",
          "--no-default-groups",
          "--output-file=requirements.txt"
        ]

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.16.1
docs/_static/llama-stack-spec.html (49 changes)

@@ -9770,7 +9770,7 @@
                {
                    "type": "array",
                    "items": {
-                       "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                       "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
                    }
                }
            ],

@@ -9821,13 +9821,17 @@
            },
            {
                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
            },
            {
                "$ref": "#/components/schemas/OpenAIFile"
            }
        ],
        "discriminator": {
            "propertyName": "type",
            "mapping": {
                "text": "#/components/schemas/OpenAIChatCompletionContentPartTextParam",
-               "image_url": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
+               "image_url": "#/components/schemas/OpenAIChatCompletionContentPartImageParam",
+               "file": "#/components/schemas/OpenAIFile"
            }
        }
    },

@@ -9955,7 +9959,7 @@
                {
                    "type": "array",
                    "items": {
-                       "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                       "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
                    }
                }
            ],

@@ -9974,6 +9978,41 @@
            "title": "OpenAIDeveloperMessageParam",
            "description": "A message from the developer in an OpenAI-compatible chat completion request."
        },
        "OpenAIFile": {
            "type": "object",
            "properties": {
                "type": {
                    "type": "string",
                    "const": "file",
                    "default": "file"
                },
                "file": {
                    "$ref": "#/components/schemas/OpenAIFileFile"
                }
            },
            "additionalProperties": false,
            "required": [
                "type",
                "file"
            ],
            "title": "OpenAIFile"
        },
        "OpenAIFileFile": {
            "type": "object",
            "properties": {
                "file_data": {
                    "type": "string"
                },
                "file_id": {
                    "type": "string"
                },
                "filename": {
                    "type": "string"
                }
            },
            "additionalProperties": false,
            "title": "OpenAIFileFile"
        },
        "OpenAIImageURL": {
            "type": "object",
            "properties": {

@@ -10036,7 +10075,7 @@
                {
                    "type": "array",
                    "items": {
-                       "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                       "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
                    }
                }
            ],

@@ -10107,7 +10146,7 @@
                {
                    "type": "array",
                    "items": {
-                       "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                       "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
                    }
                }
            ],
docs/_static/llama-stack-spec.yaml (35 changes)

@@ -6895,7 +6895,7 @@ components:
          - type: string
          - type: array
            items:
-             $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+             $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
        description: The content of the model's response
      name:
        type: string

@@ -6934,11 +6934,13 @@ components:
      oneOf:
        - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
        - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
        - $ref: '#/components/schemas/OpenAIFile'
      discriminator:
        propertyName: type
        mapping:
          text: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
          image_url: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
          file: '#/components/schemas/OpenAIFile'
    OpenAIChatCompletionContentPartTextParam:
      type: object
      properties:

@@ -7037,7 +7039,7 @@ components:
          - type: string
          - type: array
            items:
-             $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+             $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
        description: The content of the developer message
      name:
        type: string

@@ -7050,6 +7052,31 @@ components:
      title: OpenAIDeveloperMessageParam
      description: >-
        A message from the developer in an OpenAI-compatible chat completion request.
    OpenAIFile:
      type: object
      properties:
        type:
          type: string
          const: file
          default: file
        file:
          $ref: '#/components/schemas/OpenAIFileFile'
      additionalProperties: false
      required:
        - type
        - file
      title: OpenAIFile
    OpenAIFileFile:
      type: object
      properties:
        file_data:
          type: string
        file_id:
          type: string
        filename:
          type: string
      additionalProperties: false
      title: OpenAIFileFile
    OpenAIImageURL:
      type: object
      properties:

@@ -7090,7 +7117,7 @@ components:
          - type: string
          - type: array
            items:
-             $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+             $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
        description: >-
          The content of the "system prompt". If multiple system messages are provided,
          they are concatenated. The underlying Llama Stack code may also add other

@@ -7148,7 +7175,7 @@ components:
          - type: string
          - type: array
            items:
-             $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+             $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
        description: The response content from the tool
      additionalProperties: false
      required:
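Taken together, these spec changes mean a user message can now carry a `file` content part alongside text and image parts, while system, developer, tool, and assistant message content is restricted to text parts. As a rough sketch of the resulting request shape against the stack's OpenAI-compatible endpoint (the port, model id, and base64 payload below are placeholders, not taken from this diff; the integration test further down exercises the same flow end to end):

```python
# Illustrative request shape only; any OpenAI-compatible client works the same way.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")  # URL assumed
response = client.chat.completions.create(
    model="gpt-4o",  # placeholder model id
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe what you see in this PDF file."},
                {
                    "type": "file",
                    "file": {
                        "filename": "hello-world.pdf",
                        "file_data": "data:application/pdf;base64,...",  # truncated
                    },
                },
            ],
        }
    ],
)
print(response.choices[0].message.content)
```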
@@ -13,7 +13,7 @@ llama stack build --template starter --image-type venv
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient(
-    "ollama",
+    "starter",
    # provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
    provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
)
@@ -12,8 +12,7 @@ To enable external providers, you need to add `module` into your build yaml, all

an example entry in your build.yaml should look like:

```
-  - provider_id: ramalama
-    provider_type: remote::ramalama
+  - provider_type: remote::ramalama
     module: ramalama_stack
```

@@ -255,8 +254,7 @@ distribution_spec:
  container_image: null
  providers:
    inference:
-     - provider_id: ramalama
-       provider_type: remote::ramalama
+     - provider_type: remote::ramalama
        module: ramalama_stack==0.3.0a0
  image_type: venv
  image_name: null
@@ -9,11 +9,13 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `api_key` | `str \| None` | No | | API key for OpenAI models |
| `base_url` | `<class 'str'>` | No | https://api.openai.com/v1 | Base URL for OpenAI API |

## Sample Configuration

```yaml
api_key: ${env.OPENAI_API_KEY:=}
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}

```
@@ -455,8 +455,21 @@ class OpenAIChatCompletionContentPartImageParam(BaseModel):
    image_url: OpenAIImageURL


@json_schema_type
class OpenAIFileFile(BaseModel):
    file_data: str | None = None
    file_id: str | None = None
    filename: str | None = None


@json_schema_type
class OpenAIFile(BaseModel):
    type: Literal["file"] = "file"
    file: OpenAIFileFile


OpenAIChatCompletionContentPartParam = Annotated[
-   OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+   OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
    Field(discriminator="type"),
]
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")

@@ -464,6 +477,8 @@ register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletion

OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]

OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]


@json_schema_type
class OpenAIUserMessageParam(BaseModel):

@@ -489,7 +504,7 @@ class OpenAISystemMessageParam(BaseModel):
    """

    role: Literal["system"] = "system"
-   content: OpenAIChatCompletionMessageContent
+   content: OpenAIChatCompletionTextOnlyMessageContent
    name: str | None = None


@@ -518,7 +533,7 @@ class OpenAIAssistantMessageParam(BaseModel):
    """

    role: Literal["assistant"] = "assistant"
-   content: OpenAIChatCompletionMessageContent | None = None
+   content: OpenAIChatCompletionTextOnlyMessageContent | None = None
    name: str | None = None
    tool_calls: list[OpenAIChatCompletionToolCall] | None = None

@@ -534,7 +549,7 @@ class OpenAIToolMessageParam(BaseModel):

    role: Literal["tool"] = "tool"
    tool_call_id: str
-   content: OpenAIChatCompletionMessageContent
+   content: OpenAIChatCompletionTextOnlyMessageContent


@json_schema_type

@@ -547,7 +562,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
    """

    role: Literal["developer"] = "developer"
-   content: OpenAIChatCompletionMessageContent
+   content: OpenAIChatCompletionTextOnlyMessageContent
    name: str | None = None
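For illustration, the new types compose like this — a minimal sketch, assuming `OpenAIFile` and `OpenAIFileFile` are exported from `llama_stack.apis.inference` alongside the existing message classes (the unit tests below import the other names from that module):

```python
from llama_stack.apis.inference import (
    OpenAIChatCompletionContentPartTextParam,
    OpenAIFile,
    OpenAIFileFile,
    OpenAISystemMessageParam,
    OpenAIUserMessageParam,
)

# User messages may mix text, image, and now file parts.
msg = OpenAIUserMessageParam(
    content=[
        OpenAIChatCompletionContentPartTextParam(text="Summarize the attached PDF."),
        OpenAIFile(file=OpenAIFileFile(filename="report.pdf", file_data="data:application/pdf;base64,...")),
    ]
)

# System (and developer/tool/assistant) messages are now text-only, so passing a
# file or image part to them fails pydantic validation.
OpenAISystemMessageParam(content="You are a concise assistant.")
```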
@@ -18,10 +18,6 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}

# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}

# Mount command for cache container .cache, can be overridden by the user if needed
MOUNT_CACHE=${MOUNT_CACHE:-"--mount=type=cache,id=llama-stack-cache,target=/root/.cache"}

# Path to the run.yaml file in the container
RUN_CONFIG_PATH=/app/run.yaml

@@ -176,18 +172,13 @@ RUN pip install uv
EOF
fi

# Set the link mode to copy so that uv doesn't attempt to symlink to the cache directory
add_to_container << EOF
ENV UV_LINK_MODE=copy
EOF

# Add pip dependencies first since llama-stack is what will change most often
# so we can reuse layers.
if [ -n "$normal_deps" ]; then
  read -ra pip_args <<< "$normal_deps"
  quoted_deps=$(printf " %q" "${pip_args[@]}")
  add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install $quoted_deps
+RUN uv pip install --no-cache $quoted_deps
EOF
fi

@@ -197,7 +188,7 @@ if [ -n "$optional_deps" ]; then
    read -ra pip_args <<< "$part"
    quoted_deps=$(printf " %q" "${pip_args[@]}")
    add_to_container <<EOF
-RUN $MOUNT_CACHE uv pip install $quoted_deps
+RUN uv pip install --no-cache $quoted_deps
EOF
  done
fi

@@ -208,10 +199,10 @@ if [ -n "$external_provider_deps" ]; then
    read -ra pip_args <<< "$part"
    quoted_deps=$(printf " %q" "${pip_args[@]}")
    add_to_container <<EOF
-RUN $MOUNT_CACHE uv pip install $quoted_deps
+RUN uv pip install --no-cache $quoted_deps
EOF
    add_to_container <<EOF
-RUN python3 - <<PYTHON | $MOUNT_CACHE uv pip install -r -
+RUN python3 - <<PYTHON | uv pip install --no-cache -r -
import importlib
import sys

@@ -293,7 +284,7 @@ COPY $dir $mount_point
EOF
fi
add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install -e $mount_point
+RUN uv pip install --no-cache -e $mount_point
EOF
}

@@ -308,10 +299,10 @@ else
if [ -n "$TEST_PYPI_VERSION" ]; then
  # these packages are damaged in test-pypi, so install them first
  add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install fastapi libcst
+RUN uv pip install --no-cache fastapi libcst
EOF
  add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install --extra-index-url https://test.pypi.org/simple/ \
+RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
    --index-strategy unsafe-best-match \
    llama-stack==$TEST_PYPI_VERSION

@@ -323,7 +314,7 @@ EOF
  SPEC_VERSION="llama-stack"
fi
add_to_container << EOF
-RUN $MOUNT_CACHE uv pip install $SPEC_VERSION
+RUN uv pip install --no-cache $SPEC_VERSION
EOF
fi
fi
@@ -358,7 +358,7 @@ async def shutdown_stack(impls: dict[Api, Any]):


async def refresh_registry_once(impls: dict[Api, Any]):
-   logger.info("refreshing registry")
+   logger.debug("refreshing registry")
    routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
    for routing_table in routing_tables:
        await routing_table.refresh()
@@ -469,7 +469,7 @@ class HFFinetuningSingleDevice:
    use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
    save_strategy=save_strategy,
    report_to="none",
-   max_seq_length=provider_config.max_seq_length,
+   max_length=provider_config.max_seq_length,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    gradient_checkpointing=provider_config.gradient_checkpointing,
    learning_rate=lr,
@@ -32,7 +32,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
-           litellm_provider_name="llama",
+           litellm_provider_name="meta_llama",
            api_key_from_config=config.api_key,
            provider_data_api_key_field="llama_api_key",
            openai_compat_api_base=config.openai_compat_api_base,
@@ -24,9 +24,19 @@ class OpenAIConfig(BaseModel):
        default=None,
        description="API key for OpenAI models",
    )
    base_url: str = Field(
        default="https://api.openai.com/v1",
        description="Base URL for OpenAI API",
    )

    @classmethod
-   def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY:=}", **kwargs) -> dict[str, Any]:
+   def sample_run_config(
+       cls,
+       api_key: str = "${env.OPENAI_API_KEY:=}",
+       base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
+       **kwargs,
+   ) -> dict[str, Any]:
        return {
            "api_key": api_key,
            "base_url": base_url,
        }
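A quick sketch of the new config surface (the dict below mirrors what the updated `sample_run_config` returns per this diff; the proxy URL is an invented example):

```python
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig

print(OpenAIConfig.sample_run_config())
# {'api_key': '${env.OPENAI_API_KEY:=}',
#  'base_url': '${env.OPENAI_BASE_URL:=https://api.openai.com/v1}'}

# Explicit values still win over the default:
cfg = OpenAIConfig(api_key="sk-test", base_url="https://my-proxy.example.com/v1")
print(cfg.base_url)  # https://my-proxy.example.com/v1
```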
@@ -65,9 +65,9 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
        """
        Get the OpenAI API base URL.

-       Returns the standard OpenAI API base URL for direct OpenAI API calls.
+       Returns the OpenAI API base URL from the configuration.
        """
-       return "https://api.openai.com/v1"
+       return self.config.base_url

    async def initialize(self) -> None:
        await super().initialize()
@@ -73,6 +73,15 @@ class LiteLLMOpenAIMixin(
        provider_data_api_key_field: str,
        openai_compat_api_base: str | None = None,
    ):
        """
        Initialize the LiteLLMOpenAIMixin.

        :param model_entries: The model entries to register.
        :param api_key_from_config: The API key to use from the config.
        :param provider_data_api_key_field: The field in the provider data that contains the API key.
        :param litellm_provider_name: The name of the provider, used for model lookups.
        :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
        """
        ModelRegistryHelper.__init__(self, model_entries)

        self.litellm_provider_name = litellm_provider_name

@@ -428,3 +437,17 @@ class LiteLLMOpenAIMixin(
        logprobs: LogProbConfig | None = None,
    ):
        raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")

    async def check_model_availability(self, model: str) -> bool:
        """
        Check if a specific model is available via LiteLLM for the current
        provider (self.litellm_provider_name).

        :param model: The model identifier to check.
        :return: True if the model is available dynamically, False otherwise.
        """
        if self.litellm_provider_name not in litellm.models_by_provider:
            logger.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
            return False

        return model in litellm.models_by_provider[self.litellm_provider_name]
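The availability check is just a membership test against LiteLLM's public provider-to-models map, so it can be reproduced standalone. A sketch (the provider and model names below are only examples and may not be present in every litellm release):

```python
import litellm


def is_available(provider_name: str, model: str) -> bool:
    # Same lookup the mixin performs; unknown providers simply yield False.
    models = litellm.models_by_provider.get(provider_name, [])
    return model in models


print(is_available("openai", "gpt-4o"))            # expected True on recent litellm versions
print(is_available("not-a-provider", "anything"))  # False
```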
@@ -56,6 +56,7 @@ providers:
    provider_type: remote::openai
    config:
      api_key: ${env.OPENAI_API_KEY:=}
      base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
  - provider_id: anthropic
    provider_type: remote::anthropic
    config:
@@ -16,6 +16,7 @@ providers:
    provider_type: remote::openai
    config:
      api_key: ${env.OPENAI_API_KEY:=}
      base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
  - provider_id: anthropic
    provider_type: remote::anthropic
    config:
@@ -56,6 +56,7 @@ providers:
    provider_type: remote::openai
    config:
      api_key: ${env.OPENAI_API_KEY:=}
      base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
  - provider_id: anthropic
    provider_type: remote::anthropic
    config:
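All three run configs use the `${env.VAR:=default}` substitution, which is resolved at startup by the same helper the new unit tests exercise. A small sketch of the assumed behaviour (the localhost URL is an invented example):

```python
import os

from llama_stack.distribution.stack import replace_env_vars

entry = {
    "api_key": "${env.OPENAI_API_KEY:=}",
    "base_url": "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
}

print(replace_env_vars(entry)["base_url"])  # https://api.openai.com/v1 (default used)

os.environ["OPENAI_BASE_URL"] = "http://localhost:8080/v1"
print(replace_env_vars(entry)["base_url"])  # http://localhost:8080/v1
```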
@@ -20,7 +20,7 @@
    "@radix-ui/react-tooltip": "^1.2.6",
    "class-variance-authority": "^0.7.1",
    "clsx": "^2.1.1",
-   "llama-stack-client": "^0.2.15",
+   "llama-stack-client": "^0.2.16",
    "lucide-react": "^0.510.0",
    "next": "15.3.3",
    "next-auth": "^4.24.11",
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "llama_stack"
-version = "0.2.15"
+version = "0.2.16"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack"
readme = "README.md"

@@ -28,7 +28,7 @@ dependencies = [
    "huggingface-hub>=0.34.0,<1.0",
    "jinja2>=3.1.6",
    "jsonschema",
-   "llama-stack-client>=0.2.15",
+   "llama-stack-client>=0.2.16",
    "llama-api-client>=0.1.2",
    "openai>=1.66",
    "prompt-toolkit",

@@ -53,7 +53,7 @@ dependencies = [
ui = [
    "streamlit",
    "pandas",
-   "llama-stack-client>=0.2.15",
+   "llama-stack-client>=0.2.16",
    "streamlit-option-menu",
]

@@ -114,6 +114,7 @@ test = [
    "sqlalchemy[asyncio]>=2.0.41",
    "requests",
    "pymilvus>=2.5.12",
    "reportlab",
]
docs = [
    "setuptools",
requirements.txt (269 deletions)

@@ -1,269 +0,0 @@
# This file was autogenerated by uv via the following command:
# uv export --frozen --no-hashes --no-emit-project --no-default-groups --output-file=requirements.txt
aiohappyeyeballs==2.5.0
    # via aiohttp
aiohttp==3.12.13
    # via llama-stack
aiosignal==1.3.2
    # via aiohttp
aiosqlite==0.21.0
    # via llama-stack
annotated-types==0.7.0
    # via pydantic
anyio==4.8.0
    # via
    #   httpx
    #   llama-api-client
    #   llama-stack-client
    #   openai
    #   starlette
asyncpg==0.30.0
    # via llama-stack
attrs==25.1.0
    # via
    #   aiohttp
    #   jsonschema
    #   referencing
certifi==2025.1.31
    # via
    #   httpcore
    #   httpx
    #   requests
cffi==1.17.1 ; platform_python_implementation != 'PyPy'
    # via cryptography
charset-normalizer==3.4.1
    # via requests
click==8.1.8
    # via
    #   llama-stack-client
    #   uvicorn
colorama==0.4.6 ; sys_platform == 'win32'
    # via
    #   click
    #   tqdm
cryptography==45.0.5
    # via python-jose
deprecated==1.2.18
    # via
    #   opentelemetry-api
    #   opentelemetry-exporter-otlp-proto-http
    #   opentelemetry-semantic-conventions
distro==1.9.0
    # via
    #   llama-api-client
    #   llama-stack-client
    #   openai
ecdsa==0.19.1
    # via python-jose
fastapi==0.115.8
    # via llama-stack
filelock==3.17.0
    # via huggingface-hub
fire==0.7.0
    # via
    #   llama-stack
    #   llama-stack-client
frozenlist==1.5.0
    # via
    #   aiohttp
    #   aiosignal
fsspec==2024.12.0
    # via huggingface-hub
googleapis-common-protos==1.67.0
    # via opentelemetry-exporter-otlp-proto-http
h11==0.16.0
    # via
    #   httpcore
    #   llama-stack
    #   uvicorn
hf-xet==1.1.5 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
    # via huggingface-hub
httpcore==1.0.9
    # via httpx
httpx==0.28.1
    # via
    #   llama-api-client
    #   llama-stack
    #   llama-stack-client
    #   openai
huggingface-hub==0.34.1
    # via llama-stack
idna==3.10
    # via
    #   anyio
    #   httpx
    #   requests
    #   yarl
importlib-metadata==8.5.0
    # via opentelemetry-api
jinja2==3.1.6
    # via llama-stack
jiter==0.8.2
    # via openai
jsonschema==4.23.0
    # via llama-stack
jsonschema-specifications==2024.10.1
    # via jsonschema
llama-api-client==0.1.2
    # via llama-stack
llama-stack-client==0.2.15
    # via llama-stack
markdown-it-py==3.0.0
    # via rich
markupsafe==3.0.2
    # via jinja2
mdurl==0.1.2
    # via markdown-it-py
multidict==6.1.0
    # via
    #   aiohttp
    #   yarl
numpy==2.2.3
    # via pandas
openai==1.71.0
    # via llama-stack
opentelemetry-api==1.30.0
    # via
    #   opentelemetry-exporter-otlp-proto-http
    #   opentelemetry-sdk
    #   opentelemetry-semantic-conventions
opentelemetry-exporter-otlp-proto-common==1.30.0
    # via opentelemetry-exporter-otlp-proto-http
opentelemetry-exporter-otlp-proto-http==1.30.0
    # via llama-stack
opentelemetry-proto==1.30.0
    # via
    #   opentelemetry-exporter-otlp-proto-common
    #   opentelemetry-exporter-otlp-proto-http
opentelemetry-sdk==1.30.0
    # via
    #   llama-stack
    #   opentelemetry-exporter-otlp-proto-http
opentelemetry-semantic-conventions==0.51b0
    # via opentelemetry-sdk
packaging==24.2
    # via huggingface-hub
pandas==2.2.3
    # via llama-stack-client
pillow==11.1.0
    # via llama-stack
prompt-toolkit==3.0.50
    # via
    #   llama-stack
    #   llama-stack-client
propcache==0.3.0
    # via
    #   aiohttp
    #   yarl
protobuf==5.29.5
    # via
    #   googleapis-common-protos
    #   opentelemetry-proto
pyaml==25.1.0
    # via llama-stack-client
pyasn1==0.4.8
    # via
    #   python-jose
    #   rsa
pycparser==2.22 ; platform_python_implementation != 'PyPy'
    # via cffi
pydantic==2.10.6
    # via
    #   fastapi
    #   llama-api-client
    #   llama-stack
    #   llama-stack-client
    #   openai
pydantic-core==2.27.2
    # via pydantic
pygments==2.19.1
    # via rich
python-dateutil==2.9.0.post0
    # via pandas
python-dotenv==1.0.1
    # via llama-stack
python-jose==3.4.0
    # via llama-stack
python-multipart==0.0.20
    # via llama-stack
pytz==2025.1
    # via pandas
pyyaml==6.0.2
    # via
    #   huggingface-hub
    #   pyaml
referencing==0.36.2
    # via
    #   jsonschema
    #   jsonschema-specifications
regex==2024.11.6
    # via tiktoken
requests==2.32.4
    # via
    #   huggingface-hub
    #   llama-stack-client
    #   opentelemetry-exporter-otlp-proto-http
    #   tiktoken
rich==13.9.4
    # via
    #   llama-stack
    #   llama-stack-client
rpds-py==0.22.3
    # via
    #   jsonschema
    #   referencing
rsa==4.9
    # via python-jose
six==1.17.0
    # via
    #   ecdsa
    #   python-dateutil
sniffio==1.3.1
    # via
    #   anyio
    #   llama-api-client
    #   llama-stack-client
    #   openai
starlette==0.45.3
    # via
    #   fastapi
    #   llama-stack
termcolor==2.5.0
    # via
    #   fire
    #   llama-stack
    #   llama-stack-client
tiktoken==0.9.0
    # via llama-stack
tqdm==4.67.1
    # via
    #   huggingface-hub
    #   llama-stack-client
    #   openai
typing-extensions==4.12.2
    # via
    #   aiosqlite
    #   anyio
    #   fastapi
    #   huggingface-hub
    #   llama-api-client
    #   llama-stack-client
    #   openai
    #   opentelemetry-sdk
    #   pydantic
    #   pydantic-core
    #   referencing
tzdata==2025.1
    # via pandas
urllib3==2.5.0
    # via requests
uvicorn==0.34.0
    # via llama-stack
wcwidth==0.2.13
    # via prompt-toolkit
wrapt==1.17.2
    # via deprecated
yarl==1.18.3
    # via aiohttp
zipp==3.21.0
    # via importlib-metadata
@@ -8,6 +8,15 @@

PYTHON_VERSION=${PYTHON_VERSION:-3.12}

set -e

# Always run this at the end, even if something fails
cleanup() {
  echo "Generating coverage report..."
  uv run --python "$PYTHON_VERSION" coverage html -d htmlcov-$PYTHON_VERSION
}
trap cleanup EXIT

command -v uv >/dev/null 2>&1 || { echo >&2 "uv is required but it's not installed. Exiting."; exit 1; }

uv python find "$PYTHON_VERSION"

@@ -19,6 +28,3 @@ fi
# Run unit tests with coverage
uv run --python "$PYTHON_VERSION" --with-editable . --group unit \
  coverage run --source=llama_stack -m pytest -s -v tests/unit/ "$@"

# Generate HTML coverage report
uv run --python "$PYTHON_VERSION" coverage html -d htmlcov-$PYTHON_VERSION
tests/external/build.yaml (3 changes)

@@ -3,8 +3,7 @@ distribution_spec:
  description: Custom distro for CI tests
  providers:
    weather:
-     - provider_id: kaze
-       provider_type: remote::kaze
+     - provider_type: remote::kaze
  image_type: venv
  image_name: ci-test
external_providers_dir: ~/.llama/providers.d
tests/external/ramalama-stack/build.yaml (3 changes)

@@ -4,8 +4,7 @@ distribution_spec:
  container_image: null
  providers:
    inference:
-     - provider_id: ramalama
-       provider_type: remote::ramalama
+     - provider_type: remote::ramalama
        module: ramalama_stack==0.3.0a0
  image_type: venv
  image_name: ramalama-stack-test
@@ -5,8 +5,14 @@
# the root directory of this source tree.


import base64
import os
import tempfile

import pytest
from openai import OpenAI
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient


@@ -82,6 +88,14 @@ def skip_if_provider_isnt_vllm(client_with_models, model_id):
    pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support vllm extra_body parameters.")


def skip_if_provider_isnt_openai(client_with_models, model_id):
    provider = provider_from_model(client_with_models, model_id)
    if provider.provider_type != "remote::openai":
        pytest.skip(
            f"Model {model_id} hosted by {provider.provider_type} doesn't support chat completion calls with base64 encoded files."
        )


@pytest.fixture
def openai_client(client_with_models):
    base_url = f"{client_with_models.base_url}/v1/openai/v1"

@@ -418,3 +432,45 @@ def test_inference_store_tool_calls(compat_client, client_with_models, text_mode
    # failed tool call parses show up as a message with content, so ensure
    # that the retrieve response content matches the original request
    assert retrieved_response.choices[0].message.content == content


def test_openai_chat_completion_non_streaming_with_file(openai_client, client_with_models, text_model_id):
    skip_if_provider_isnt_openai(client_with_models, text_model_id)

    # Generate temporary PDF with "Hello World" text
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_pdf:
        c = canvas.Canvas(temp_pdf.name, pagesize=letter)
        c.drawString(100, 750, "Hello World")
        c.save()

    # Read the PDF and encode to base64
    with open(temp_pdf.name, "rb") as pdf_file:
        pdf_base64 = base64.b64encode(pdf_file.read()).decode("utf-8")

    # Clean up temporary file
    os.unlink(temp_pdf.name)

    response = openai_client.chat.completions.create(
        model=text_model_id,
        messages=[
            {
                "role": "user",
                "content": "Describe what you see in this PDF file.",
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "file",
                        "file": {
                            "filename": "my-temp-hello-world-pdf",
                            "file_data": f"data:application/pdf;base64,{pdf_base64}",
                        },
                    }
                ],
            },
        ],
        stream=False,
    )
    message_content = response.choices[0].message.content.lower().strip()
    assert "hello world" in message_content
@@ -38,9 +38,8 @@ sys.stdout.reconfigure(line_buffering=True)

# How to run this test:
#
-# pytest llama_stack/providers/tests/post_training/test_post_training.py
-# -m "torchtune_post_training_huggingface_datasetio"
-# -v -s --tb=short --disable-warnings
+# LLAMA_STACK_CONFIG=ci-tests uv run --dev pytest tests/integration/post_training/test_post_training.py
#


class TestPostTraining:

@@ -113,6 +112,7 @@ class TestPostTraining:
                break

            logger.info(f"Current status: {status}")
            assert status.status in ["scheduled", "in_progress", "completed"]
            if status.status == "completed":
                break
|
|||
|
||||
def test_external_provider_from_module_building(self, mock_providers):
|
||||
"""Test loading an external provider from a module during build (building=True, partial spec)."""
|
||||
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec, Provider
|
||||
from llama_stack.distribution.datatypes import BuildConfig, BuildProvider, DistributionSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
# No importlib patch needed, should not import module when type of `config` is BuildConfig or DistributionSpec
|
||||
|
@ -358,10 +358,8 @@ pip_packages:
|
|||
description="test",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="external_test",
|
||||
BuildProvider(
|
||||
provider_type="external_test",
|
||||
config={},
|
||||
module="external_test",
|
||||
)
|
||||
]
|
||||
|
|
tests/unit/providers/inference/test_openai_base_url_config.py (new file, 125 lines)

@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from unittest.mock import AsyncMock, MagicMock, patch

from llama_stack.distribution.stack import replace_env_vars
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter


class TestOpenAIBaseURLConfig:
    """Test that OPENAI_BASE_URL environment variable properly configures the OpenAI adapter."""

    def test_default_base_url_without_env_var(self):
        """Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
        config = OpenAIConfig(api_key="test-key")
        adapter = OpenAIInferenceAdapter(config)

        assert adapter.get_base_url() == "https://api.openai.com/v1"

    def test_custom_base_url_from_config(self):
        """Test that the adapter uses a custom base URL when provided in config."""
        custom_url = "https://custom.openai.com/v1"
        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
        adapter = OpenAIInferenceAdapter(config)

        assert adapter.get_base_url() == custom_url

    @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
    def test_base_url_from_environment_variable(self):
        """Test that the adapter uses base URL from OPENAI_BASE_URL environment variable."""
        # Use sample_run_config which has proper environment variable syntax
        config_data = OpenAIConfig.sample_run_config(api_key="test-key")
        processed_config = replace_env_vars(config_data)
        config = OpenAIConfig.model_validate(processed_config)
        adapter = OpenAIInferenceAdapter(config)

        assert adapter.get_base_url() == "https://env.openai.com/v1"

    @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
    def test_config_overrides_environment_variable(self):
        """Test that explicit config value overrides environment variable."""
        custom_url = "https://config.openai.com/v1"
        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
        adapter = OpenAIInferenceAdapter(config)

        # Config should take precedence over environment variable
        assert adapter.get_base_url() == custom_url

    @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
    def test_client_uses_configured_base_url(self, mock_openai_class):
        """Test that the OpenAI client is initialized with the configured base URL."""
        custom_url = "https://test.openai.com/v1"
        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
        adapter = OpenAIInferenceAdapter(config)

        # Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
        adapter.get_api_key = MagicMock(return_value="test-key")

        # Access the client property to trigger AsyncOpenAI initialization
        _ = adapter.client

        # Verify AsyncOpenAI was called with the correct base_url
        mock_openai_class.assert_called_once_with(
            api_key="test-key",
            base_url=custom_url,
        )

    @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
    async def test_check_model_availability_uses_configured_url(self, mock_openai_class):
        """Test that check_model_availability uses the configured base URL."""
        custom_url = "https://test.openai.com/v1"
        config = OpenAIConfig(api_key="test-key", base_url=custom_url)
        adapter = OpenAIInferenceAdapter(config)

        # Mock the get_api_key method
        adapter.get_api_key = MagicMock(return_value="test-key")

        # Mock the AsyncOpenAI client and its models.retrieve method
        mock_client = MagicMock()
        mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
        mock_openai_class.return_value = mock_client

        # Call check_model_availability and verify it returns True
        assert await adapter.check_model_availability("gpt-4")

        # Verify the client was created with the custom URL
        mock_openai_class.assert_called_with(
            api_key="test-key",
            base_url=custom_url,
        )

        # Verify the method was called and returned True
        mock_client.models.retrieve.assert_called_once_with("gpt-4")

    @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"})
    @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
    async def test_environment_variable_affects_model_availability_check(self, mock_openai_class):
        """Test that setting OPENAI_BASE_URL environment variable affects where model availability is checked."""
        # Use sample_run_config which has proper environment variable syntax
        config_data = OpenAIConfig.sample_run_config(api_key="test-key")
        processed_config = replace_env_vars(config_data)
        config = OpenAIConfig.model_validate(processed_config)
        adapter = OpenAIInferenceAdapter(config)

        # Mock the get_api_key method
        adapter.get_api_key = MagicMock(return_value="test-key")

        # Mock the AsyncOpenAI client
        mock_client = MagicMock()
        mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
        mock_openai_class.return_value = mock_client

        # Call check_model_availability and verify it returns True
        assert await adapter.check_model_availability("gpt-4")

        # Verify the client was created with the environment variable URL
        mock_openai_class.assert_called_with(
            api_key="test-key",
            base_url="https://proxy.openai.com/v1",
        )
@@ -4,13 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from pydantic import ValidationError

from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import (
    CompletionMessage,
    OpenAIAssistantMessageParam,
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIDeveloperMessageParam,
    OpenAIImageURL,
    OpenAISystemMessageParam,
    OpenAIToolMessageParam,
    OpenAIUserMessageParam,
    SystemMessage,
    UserMessage,

@@ -108,3 +114,71 @@ async def test_openai_messages_to_messages_with_content_list():
    assert llama_messages[0].content[0].text == "system message"
    assert llama_messages[1].content[0].text == "user message"
    assert llama_messages[2].content[0].text == "assistant message"


@pytest.mark.parametrize(
    "message_class,kwargs",
    [
        (OpenAISystemMessageParam, {}),
        (OpenAIAssistantMessageParam, {}),
        (OpenAIDeveloperMessageParam, {}),
        (OpenAIUserMessageParam, {}),
        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
    ],
)
def test_message_accepts_text_string(message_class, kwargs):
    """Test that messages accept string text content."""
    msg = message_class(content="Test message", **kwargs)
    assert msg.content == "Test message"


@pytest.mark.parametrize(
    "message_class,kwargs",
    [
        (OpenAISystemMessageParam, {}),
        (OpenAIAssistantMessageParam, {}),
        (OpenAIDeveloperMessageParam, {}),
        (OpenAIUserMessageParam, {}),
        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
    ],
)
def test_message_accepts_text_list(message_class, kwargs):
    """Test that messages accept list of text content parts."""
    content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
    msg = message_class(content=content_list, **kwargs)
    assert len(msg.content) == 1
    assert msg.content[0].text == "Test message"


@pytest.mark.parametrize(
    "message_class,kwargs",
    [
        (OpenAISystemMessageParam, {}),
        (OpenAIAssistantMessageParam, {}),
        (OpenAIDeveloperMessageParam, {}),
        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
    ],
)
def test_message_rejects_images(message_class, kwargs):
    """Test that system, assistant, developer, and tool messages reject image content."""
    with pytest.raises(ValidationError):
        message_class(
            content=[
                OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
            ],
            **kwargs,
        )


def test_user_message_accepts_images():
    """Test that user messages accept image content (unlike other message types)."""
    # List with images should work
    msg = OpenAIUserMessageParam(
        content=[
            OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
            OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
        ]
    )
    assert len(msg.content) == 2
    assert msg.content[0].text == "Describe this image:"
    assert msg.content[1].image_url.url == "http://example.com/image.jpg"
@@ -162,26 +162,29 @@ async def test_register_model_existing_different(
    await helper.register_model(known_model)


-async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
-    await helper.register_model(known_model)  # duplicate entry
-    assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
-    await helper.unregister_model(known_model.model_id)
-    assert helper.get_provider_model_id(known_model.model_id) is None
+# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
+# async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
+#     await helper.register_model(known_model)  # duplicate entry
+#     assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
+#     await helper.unregister_model(known_model.model_id)
+#     assert helper.get_provider_model_id(known_model.model_id) is None


-async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
-    with pytest.raises(ValueError):
-        await helper.unregister_model(unknown_model.model_id)
+# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
+# async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
+#     with pytest.raises(ValueError):
+#         await helper.unregister_model(unknown_model.model_id)


async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
    assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id


-async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
-    assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
-    await helper.unregister_model(known_model.provider_resource_id)
-    assert helper.get_provider_model_id(known_model.provider_resource_id) is None
+# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
+# async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
+#     assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
+#     await helper.unregister_model(known_model.provider_resource_id)
+#     assert helper.get_provider_model_id(known_model.provider_resource_id) is None


async def test_register_model_from_check_model_availability(
@@ -49,7 +49,7 @@ def github_token_app():
    )

    # Add auth middleware
-   app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
+   app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():

@@ -149,7 +149,7 @@ def test_github_enterprise_support(mock_client_class):
        access_policy=[],
    )

-   app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
+   app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():