From 968fc132d3e0a824b45709c063afbdc2e55623ca Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Mon, 28 Jul 2025 13:36:34 -0400
Subject: [PATCH] fix(openai-compat): restrict developer/assistant/system/tool messages to text-only content (#2932)

**What:**
- Added OpenAIChatCompletionTextOnlyMessageContent type for text-only content validation
- Modified OpenAISystemMessageParam, OpenAIAssistantMessageParam, OpenAIDeveloperMessageParam, and OpenAIToolMessageParam to use text-only content type instead of mixed content
- OpenAIUserMessageParam unchanged - still accepts both text and images
- Updated OpenAPI spec files to reflect text-only content restrictions in schemas

closes #2894

**Why:**
- Enforces OpenAI API compatibility by restricting image content to user messages only
- Prevents API misuse where images might be sent in message types that don't support them
- Aligns with OpenAI's actual API behavior where only user messages can contain multimodal content
- Improves type safety and validation at the API boundary

**Test plan:**
- Added comprehensive parametrized tests covering all 5 OpenAI message types
- Tests verify text string acceptance for all message types
- Tests verify text list acceptance for all message types
- Tests verify image rejection for system/assistant/developer/tool messages (ValidationError expected)
- Tests verify user messages still accept images (backward compatibility maintained)
---
 docs/_static/llama-stack-spec.html        |  8 +-
 docs/_static/llama-stack-spec.yaml        |  8 +-
 llama_stack/apis/inference/inference.py   | 10 ++-
 .../utils/inference/test_openai_compat.py | 74 +++++++++++++++++++
 4 files changed, 88 insertions(+), 12 deletions(-)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 38e53a438..60f970782 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -9770,7 +9770,7 @@
                 {
                     "type": "array",
                     "items": {
-                        "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                        "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
                     }
                 }
             ],
@@ -9955,7 +9955,7 @@
                 {
                     "type": "array",
                     "items": {
-                        "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                        "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
                     }
                 }
             ],
@@ -10036,7 +10036,7 @@
                 {
                     "type": "array",
                     "items": {
-                        "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                        "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
                     }
                 }
             ],
@@ -10107,7 +10107,7 @@
                 {
                     "type": "array",
                     "items": {
-                        "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                        "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
                     }
                 }
             ],
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 0df60ddf4..36e432ab3 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -6895,7 +6895,7 @@ components:
             - type: string
             - type: array
               items:
-                $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+                $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
           description: The content of the model's response
         name:
           type: string
@@ -7037,7 +7037,7 @@ components:
             - type: string
             - type: array
               items:
-                $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+                $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
           description: The content of the developer message
         name:
          type: string
@@ -7090,7 +7090,7 @@ components:
            - type: string
            - type: array
              items:
-               $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+               $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
          description: >-
            The content of the "system prompt". If multiple system messages are provided,
            they are concatenated. The underlying Llama Stack code may also add other
@@ -7148,7 +7148,7 @@ components:
            - type: string
            - type: array
              items:
-               $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+               $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
          description: The response content from the tool
       additionalProperties: false
       required:
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 222099064..796fee65d 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -464,6 +464,8 @@ register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletion
 
 OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
 
+OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
+
 
 @json_schema_type
 class OpenAIUserMessageParam(BaseModel):
@@ -489,7 +491,7 @@ class OpenAISystemMessageParam(BaseModel):
     """
 
     role: Literal["system"] = "system"
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent
     name: str | None = None
 
 
@@ -518,7 +520,7 @@ class OpenAIAssistantMessageParam(BaseModel):
     """
 
     role: Literal["assistant"] = "assistant"
-    content: OpenAIChatCompletionMessageContent | None = None
+    content: OpenAIChatCompletionTextOnlyMessageContent | None = None
     name: str | None = None
     tool_calls: list[OpenAIChatCompletionToolCall] | None = None
 
@@ -534,7 +536,7 @@ class OpenAIToolMessageParam(BaseModel):
 
     role: Literal["tool"] = "tool"
     tool_call_id: str
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent
 
 
 @json_schema_type
@@ -547,7 +549,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
     """
 
     role: Literal["developer"] = "developer"
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent
     name: str | None = None
 
 
diff --git a/tests/unit/providers/utils/inference/test_openai_compat.py b/tests/unit/providers/utils/inference/test_openai_compat.py
index f57f6c9b3..5b8527d1b 100644
--- a/tests/unit/providers/utils/inference/test_openai_compat.py
+++ b/tests/unit/providers/utils/inference/test_openai_compat.py
@@ -4,13 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import pytest
+from pydantic import ValidationError
+
 from llama_stack.apis.common.content_types import TextContentItem
 from llama_stack.apis.inference import (
     CompletionMessage,
     OpenAIAssistantMessageParam,
+    OpenAIChatCompletionContentPartImageParam,
     OpenAIChatCompletionContentPartTextParam,
+    OpenAIDeveloperMessageParam,
+    OpenAIImageURL,
     OpenAISystemMessageParam,
+    OpenAIToolMessageParam,
     OpenAIUserMessageParam,
     SystemMessage,
     UserMessage,
@@ -108,3 +114,71 @@ async def test_openai_messages_to_messages_with_content_list():
     assert llama_messages[0].content[0].text == "system message"
     assert llama_messages[1].content[0].text == "user message"
     assert llama_messages[2].content[0].text == "assistant message"
+
+
+@pytest.mark.parametrize(
+    "message_class,kwargs",
+    [
+        (OpenAISystemMessageParam, {}),
+        (OpenAIAssistantMessageParam, {}),
+        (OpenAIDeveloperMessageParam, {}),
+        (OpenAIUserMessageParam, {}),
+        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+    ],
+)
+def test_message_accepts_text_string(message_class, kwargs):
+    """Test that messages accept string text content."""
+    msg = message_class(content="Test message", **kwargs)
+    assert msg.content == "Test message"
+
+
+@pytest.mark.parametrize(
+    "message_class,kwargs",
+    [
+        (OpenAISystemMessageParam, {}),
+        (OpenAIAssistantMessageParam, {}),
+        (OpenAIDeveloperMessageParam, {}),
+        (OpenAIUserMessageParam, {}),
+        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+    ],
+)
+def test_message_accepts_text_list(message_class, kwargs):
+    """Test that messages accept list of text content parts."""
+    content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
+    msg = message_class(content=content_list, **kwargs)
+    assert len(msg.content) == 1
+    assert msg.content[0].text == "Test message"
+
+
+@pytest.mark.parametrize(
+    "message_class,kwargs",
+    [
+        (OpenAISystemMessageParam, {}),
+        (OpenAIAssistantMessageParam, {}),
+        (OpenAIDeveloperMessageParam, {}),
+        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+    ],
+)
+def test_message_rejects_images(message_class, kwargs):
+    """Test that system, assistant, developer, and tool messages reject image content."""
+    with pytest.raises(ValidationError):
+        message_class(
+            content=[
+                OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
+            ],
+            **kwargs,
+        )
+
+
+def test_user_message_accepts_images():
+    """Test that user messages accept image content (unlike other message types)."""
+    # List with images should work
+    msg = OpenAIUserMessageParam(
+        content=[
+            OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
+            OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
+        ]
+    )
+    assert len(msg.content) == 2
+    assert msg.content[0].text == "Describe this image:"
+    assert msg.content[1].image_url.url == "http://example.com/image.jpg"
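
For reviewers, a minimal sketch (not part of the patch) of the behavior the new OpenAIChatCompletionTextOnlyMessageContent alias enforces. It mirrors the parametrized tests above and assumes the names shown in the diff are importable from llama_stack.apis.inference and that pydantic v2 supplies ValidationError:

```python
# Illustrative only: mirrors the parametrized tests above; not code from this patch.
from pydantic import ValidationError

from llama_stack.apis.inference import (
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIImageURL,
    OpenAISystemMessageParam,
    OpenAIUserMessageParam,
)

image_part = OpenAIChatCompletionContentPartImageParam(
    image_url=OpenAIImageURL(url="http://example.com/image.jpg")
)

# User messages stay multimodal: a mixed text + image content list still validates.
user_msg = OpenAIUserMessageParam(
    content=[OpenAIChatCompletionContentPartTextParam(text="Describe this image:"), image_part]
)
assert len(user_msg.content) == 2

# Every other message type is now text-only: an image part fails pydantic validation.
try:
    OpenAISystemMessageParam(content=[image_part])
except ValidationError as err:
    print(f"system message rejected image content: {len(err.errors())} validation error(s)")
```

The same rejection applies to assistant, developer, and tool messages, since they all share the new text-only content alias.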