diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 38e53a438..60f970782 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -9770,7 +9770,7 @@
{
"type": "array",
"items": {
- "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+ "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
}
}
],
@@ -9955,7 +9955,7 @@
{
"type": "array",
"items": {
- "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+ "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
}
}
],
@@ -10036,7 +10036,7 @@
{
"type": "array",
"items": {
- "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+ "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
}
}
],
@@ -10107,7 +10107,7 @@
{
"type": "array",
"items": {
- "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+ "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
}
}
],
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 0df60ddf4..36e432ab3 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -6895,7 +6895,7 @@ components:
- type: string
- type: array
items:
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+ $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
description: The content of the model's response
name:
type: string
@@ -7037,7 +7037,7 @@ components:
- type: string
- type: array
items:
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+ $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
description: The content of the developer message
name:
type: string
@@ -7090,7 +7090,7 @@ components:
- type: string
- type: array
items:
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+ $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
description: >-
The content of the "system prompt". If multiple system messages are provided,
they are concatenated. The underlying Llama Stack code may also add other
@@ -7148,7 +7148,7 @@ components:
- type: string
- type: array
items:
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+ $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
description: The response content from the tool
additionalProperties: false
required:
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 222099064..796fee65d 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -464,6 +464,8 @@ register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletion
OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
+OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
+
@json_schema_type
class OpenAIUserMessageParam(BaseModel):
@@ -489,7 +491,7 @@ class OpenAISystemMessageParam(BaseModel):
"""
role: Literal["system"] = "system"
- content: OpenAIChatCompletionMessageContent
+ content: OpenAIChatCompletionTextOnlyMessageContent
name: str | None = None
@@ -518,7 +520,7 @@ class OpenAIAssistantMessageParam(BaseModel):
"""
role: Literal["assistant"] = "assistant"
- content: OpenAIChatCompletionMessageContent | None = None
+ content: OpenAIChatCompletionTextOnlyMessageContent | None = None
name: str | None = None
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
@@ -534,7 +536,7 @@ class OpenAIToolMessageParam(BaseModel):
role: Literal["tool"] = "tool"
tool_call_id: str
- content: OpenAIChatCompletionMessageContent
+ content: OpenAIChatCompletionTextOnlyMessageContent
@json_schema_type
@@ -547,7 +549,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
"""
role: Literal["developer"] = "developer"
- content: OpenAIChatCompletionMessageContent
+ content: OpenAIChatCompletionTextOnlyMessageContent
name: str | None = None
diff --git a/tests/unit/providers/utils/inference/test_openai_compat.py b/tests/unit/providers/utils/inference/test_openai_compat.py
index f57f6c9b3..5b8527d1b 100644
--- a/tests/unit/providers/utils/inference/test_openai_compat.py
+++ b/tests/unit/providers/utils/inference/test_openai_compat.py
@@ -4,13 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+import pytest
+from pydantic import ValidationError
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import (
CompletionMessage,
OpenAIAssistantMessageParam,
+ OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
+ OpenAIDeveloperMessageParam,
+ OpenAIImageURL,
OpenAISystemMessageParam,
+ OpenAIToolMessageParam,
OpenAIUserMessageParam,
SystemMessage,
UserMessage,
@@ -108,3 +114,71 @@ async def test_openai_messages_to_messages_with_content_list():
assert llama_messages[0].content[0].text == "system message"
assert llama_messages[1].content[0].text == "user message"
assert llama_messages[2].content[0].text == "assistant message"
+
+
+@pytest.mark.parametrize(
+ "message_class,kwargs",
+ [
+ (OpenAISystemMessageParam, {}),
+ (OpenAIAssistantMessageParam, {}),
+ (OpenAIDeveloperMessageParam, {}),
+ (OpenAIUserMessageParam, {}),
+ (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+ ],
+)
+def test_message_accepts_text_string(message_class, kwargs):
+ """Test that messages accept string text content."""
+ msg = message_class(content="Test message", **kwargs)
+ assert msg.content == "Test message"
+
+
+@pytest.mark.parametrize(
+ "message_class,kwargs",
+ [
+ (OpenAISystemMessageParam, {}),
+ (OpenAIAssistantMessageParam, {}),
+ (OpenAIDeveloperMessageParam, {}),
+ (OpenAIUserMessageParam, {}),
+ (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+ ],
+)
+def test_message_accepts_text_list(message_class, kwargs):
+ """Test that messages accept list of text content parts."""
+ content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
+ msg = message_class(content=content_list, **kwargs)
+ assert len(msg.content) == 1
+ assert msg.content[0].text == "Test message"
+
+
+@pytest.mark.parametrize(
+ "message_class,kwargs",
+ [
+ (OpenAISystemMessageParam, {}),
+ (OpenAIAssistantMessageParam, {}),
+ (OpenAIDeveloperMessageParam, {}),
+ (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+ ],
+)
+def test_message_rejects_images(message_class, kwargs):
+ """Test that system, assistant, developer, and tool messages reject image content."""
+ with pytest.raises(ValidationError):
+ message_class(
+ content=[
+ OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
+ ],
+ **kwargs,
+ )
+
+
+def test_user_message_accepts_images():
+ """Test that user messages accept image content (unlike other message types)."""
+ # List with images should work
+ msg = OpenAIUserMessageParam(
+ content=[
+ OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
+ OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
+ ]
+ )
+ assert len(msg.content) == 2
+ assert msg.content[0].text == "Describe this image:"
+ assert msg.content[1].image_url.url == "http://example.com/image.jpg"