diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 38e53a438..60f970782 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9770,7 +9770,7 @@ { "type": "array", "items": { - "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" } } ], @@ -9955,7 +9955,7 @@ { "type": "array", "items": { - "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" } } ], @@ -10036,7 +10036,7 @@ { "type": "array", "items": { - "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" } } ], @@ -10107,7 +10107,7 @@ { "type": "array", "items": { - "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" } } ], diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 0df60ddf4..36e432ab3 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6895,7 +6895,7 @@ components: - type: string - type: array items: - $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' + $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' description: The content of the model's response name: type: string @@ -7037,7 +7037,7 @@ components: - type: string - type: array items: - $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' + $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' description: The content of the developer message name: type: string @@ -7090,7 +7090,7 @@ components: - type: string - type: array items: - $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' + $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' description: >- The content of the "system prompt". If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other @@ -7148,7 +7148,7 @@ components: - type: string - type: array items: - $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' + $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' description: The response content from the tool additionalProperties: false required: diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 222099064..796fee65d 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -464,6 +464,8 @@ register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletion OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam] +OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam] + @json_schema_type class OpenAIUserMessageParam(BaseModel): @@ -489,7 +491,7 @@ class OpenAISystemMessageParam(BaseModel): """ role: Literal["system"] = "system" - content: OpenAIChatCompletionMessageContent + content: OpenAIChatCompletionTextOnlyMessageContent name: str | None = None @@ -518,7 +520,7 @@ class OpenAIAssistantMessageParam(BaseModel): """ role: Literal["assistant"] = "assistant" - content: OpenAIChatCompletionMessageContent | None = None + content: OpenAIChatCompletionTextOnlyMessageContent | None = None name: str | None = None tool_calls: list[OpenAIChatCompletionToolCall] | None = None @@ -534,7 +536,7 @@ class OpenAIToolMessageParam(BaseModel): role: Literal["tool"] = "tool" tool_call_id: str - content: OpenAIChatCompletionMessageContent + content: OpenAIChatCompletionTextOnlyMessageContent @json_schema_type @@ -547,7 +549,7 @@ class OpenAIDeveloperMessageParam(BaseModel): """ role: Literal["developer"] = "developer" - content: OpenAIChatCompletionMessageContent + content: OpenAIChatCompletionTextOnlyMessageContent name: str | None = None diff --git a/tests/unit/providers/utils/inference/test_openai_compat.py b/tests/unit/providers/utils/inference/test_openai_compat.py index f57f6c9b3..5b8527d1b 100644 --- a/tests/unit/providers/utils/inference/test_openai_compat.py +++ b/tests/unit/providers/utils/inference/test_openai_compat.py @@ -4,13 +4,19 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import pytest +from pydantic import ValidationError from llama_stack.apis.common.content_types import TextContentItem from llama_stack.apis.inference import ( CompletionMessage, OpenAIAssistantMessageParam, + OpenAIChatCompletionContentPartImageParam, OpenAIChatCompletionContentPartTextParam, + OpenAIDeveloperMessageParam, + OpenAIImageURL, OpenAISystemMessageParam, + OpenAIToolMessageParam, OpenAIUserMessageParam, SystemMessage, UserMessage, @@ -108,3 +114,71 @@ async def test_openai_messages_to_messages_with_content_list(): assert llama_messages[0].content[0].text == "system message" assert llama_messages[1].content[0].text == "user message" assert llama_messages[2].content[0].text == "assistant message" + + +@pytest.mark.parametrize( + "message_class,kwargs", + [ + (OpenAISystemMessageParam, {}), + (OpenAIAssistantMessageParam, {}), + (OpenAIDeveloperMessageParam, {}), + (OpenAIUserMessageParam, {}), + (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), + ], +) +def test_message_accepts_text_string(message_class, kwargs): + """Test that messages accept string text content.""" + msg = message_class(content="Test message", **kwargs) + assert msg.content == "Test message" + + +@pytest.mark.parametrize( + "message_class,kwargs", + [ + (OpenAISystemMessageParam, {}), + (OpenAIAssistantMessageParam, {}), + (OpenAIDeveloperMessageParam, {}), + (OpenAIUserMessageParam, {}), + (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), + ], +) +def test_message_accepts_text_list(message_class, kwargs): + """Test that messages accept list of text content parts.""" + content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")] + msg = message_class(content=content_list, **kwargs) + assert len(msg.content) == 1 + assert msg.content[0].text == "Test message" + + +@pytest.mark.parametrize( + "message_class,kwargs", + [ + (OpenAISystemMessageParam, {}), + (OpenAIAssistantMessageParam, {}), + (OpenAIDeveloperMessageParam, {}), + (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), + ], +) +def test_message_rejects_images(message_class, kwargs): + """Test that system, assistant, developer, and tool messages reject image content.""" + with pytest.raises(ValidationError): + message_class( + content=[ + OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")) + ], + **kwargs, + ) + + +def test_user_message_accepts_images(): + """Test that user messages accept image content (unlike other message types).""" + # List with images should work + msg = OpenAIUserMessageParam( + content=[ + OpenAIChatCompletionContentPartTextParam(text="Describe this image:"), + OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")), + ] + ) + assert len(msg.content) == 2 + assert msg.content[0].text == "Describe this image:" + assert msg.content[1].image_url.url == "http://example.com/image.jpg"