Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-30 07:39:38 +00:00
fix(openai-compat): restrict developer/assistant/system/tool messages to text-only content (#2932)
**What:**
- Added `OpenAIChatCompletionTextOnlyMessageContent` type for text-only content validation
- Modified `OpenAISystemMessageParam`, `OpenAIAssistantMessageParam`, `OpenAIDeveloperMessageParam`, and `OpenAIToolMessageParam` to use the text-only content type instead of mixed content
- `OpenAIUserMessageParam` unchanged - still accepts both text and images
- Updated OpenAPI spec files to reflect the text-only content restrictions in the schemas

Closes #2894

**Why:**
- Enforces OpenAI API compatibility by restricting image content to user messages only
- Prevents API misuse where images might be sent in message types that don't support them
- Aligns with OpenAI's actual API behavior, where only user messages can contain multimodal content
- Improves type safety and validation at the API boundary

**Test plan:**
- Added comprehensive parametrized tests covering all 5 OpenAI message types
- Tests verify that every message type accepts a text string
- Tests verify that every message type accepts a list of text content parts
- Tests verify that system/assistant/developer/tool messages reject image content (ValidationError expected)
- Tests verify that user messages still accept images (backward compatibility maintained)
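For illustration, here is a minimal sketch of the resulting validation behavior. It mirrors the new parametrized tests in the diff below and uses the same public imports; treat it as illustrative rather than authoritative.

```python
from pydantic import ValidationError

from llama_stack.apis.inference import (
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIImageURL,
    OpenAISystemMessageParam,
    OpenAIUserMessageParam,
)

image_part = OpenAIChatCompletionContentPartImageParam(
    image_url=OpenAIImageURL(url="http://example.com/image.jpg")
)

# User messages keep accepting mixed text + image content (unchanged).
user_msg = OpenAIUserMessageParam(
    content=[
        OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
        image_part,
    ]
)
assert len(user_msg.content) == 2

# System (and assistant/developer/tool) messages are now text-only, so an
# image content part fails Pydantic validation.
try:
    OpenAISystemMessageParam(content=[image_part])
except ValidationError:
    print("image content rejected for system messages")
```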
This commit is contained in: parent 60bb5e307e, commit 968fc132d3

4 changed files with 88 additions and 12 deletions
docs/_static/llama-stack-spec.html (vendored), 8 changed lines
@@ -9770,7 +9770,7 @@
             {
               "type": "array",
               "items": {
-                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
               }
             }
           ],

@@ -9955,7 +9955,7 @@
             {
               "type": "array",
               "items": {
-                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
               }
             }
           ],

@@ -10036,7 +10036,7 @@
             {
               "type": "array",
               "items": {
-                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
               }
             }
           ],

@@ -10107,7 +10107,7 @@
             {
               "type": "array",
               "items": {
-                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam"
+                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
               }
             }
           ],
docs/_static/llama-stack-spec.yaml (vendored), 8 changed lines
@@ -6895,7 +6895,7 @@ components:
           - type: string
           - type: array
             items:
-              $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+              $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
         description: The content of the model's response
       name:
         type: string

@@ -7037,7 +7037,7 @@ components:
           - type: string
           - type: array
             items:
-              $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+              $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
         description: The content of the developer message
       name:
         type: string

@@ -7090,7 +7090,7 @@ components:
           - type: string
           - type: array
             items:
-              $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+              $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
         description: >-
           The content of the "system prompt". If multiple system messages are provided,
           they are concatenated. The underlying Llama Stack code may also add other

@@ -7148,7 +7148,7 @@ components:
           - type: string
           - type: array
             items:
-              $ref: '#/components/schemas/OpenAIChatCompletionContentPartParam'
+              $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
         description: The response content from the tool
     additionalProperties: false
     required:
@@ -464,6 +464,8 @@ register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletion

 OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]

+OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
+

 @json_schema_type
 class OpenAIUserMessageParam(BaseModel):

@@ -489,7 +491,7 @@ class OpenAISystemMessageParam(BaseModel):
     """

     role: Literal["system"] = "system"
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent
     name: str | None = None


@@ -518,7 +520,7 @@ class OpenAIAssistantMessageParam(BaseModel):
     """

     role: Literal["assistant"] = "assistant"
-    content: OpenAIChatCompletionMessageContent | None = None
+    content: OpenAIChatCompletionTextOnlyMessageContent | None = None
     name: str | None = None
     tool_calls: list[OpenAIChatCompletionToolCall] | None = None


@@ -534,7 +536,7 @@ class OpenAIToolMessageParam(BaseModel):

     role: Literal["tool"] = "tool"
     tool_call_id: str
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent


 @json_schema_type

@@ -547,7 +549,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
     """

     role: Literal["developer"] = "developer"
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent
     name: str | None = None

@@ -4,13 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import pytest
+from pydantic import ValidationError

 from llama_stack.apis.common.content_types import TextContentItem
 from llama_stack.apis.inference import (
     CompletionMessage,
     OpenAIAssistantMessageParam,
+    OpenAIChatCompletionContentPartImageParam,
     OpenAIChatCompletionContentPartTextParam,
+    OpenAIDeveloperMessageParam,
+    OpenAIImageURL,
     OpenAISystemMessageParam,
+    OpenAIToolMessageParam,
     OpenAIUserMessageParam,
     SystemMessage,
     UserMessage,

@@ -108,3 +114,71 @@ async def test_openai_messages_to_messages_with_content_list():
     assert llama_messages[0].content[0].text == "system message"
     assert llama_messages[1].content[0].text == "user message"
     assert llama_messages[2].content[0].text == "assistant message"
+
+
+@pytest.mark.parametrize(
+    "message_class,kwargs",
+    [
+        (OpenAISystemMessageParam, {}),
+        (OpenAIAssistantMessageParam, {}),
+        (OpenAIDeveloperMessageParam, {}),
+        (OpenAIUserMessageParam, {}),
+        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+    ],
+)
+def test_message_accepts_text_string(message_class, kwargs):
+    """Test that messages accept string text content."""
+    msg = message_class(content="Test message", **kwargs)
+    assert msg.content == "Test message"
+
+
+@pytest.mark.parametrize(
+    "message_class,kwargs",
+    [
+        (OpenAISystemMessageParam, {}),
+        (OpenAIAssistantMessageParam, {}),
+        (OpenAIDeveloperMessageParam, {}),
+        (OpenAIUserMessageParam, {}),
+        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+    ],
+)
+def test_message_accepts_text_list(message_class, kwargs):
+    """Test that messages accept list of text content parts."""
+    content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
+    msg = message_class(content=content_list, **kwargs)
+    assert len(msg.content) == 1
+    assert msg.content[0].text == "Test message"
+
+
+@pytest.mark.parametrize(
+    "message_class,kwargs",
+    [
+        (OpenAISystemMessageParam, {}),
+        (OpenAIAssistantMessageParam, {}),
+        (OpenAIDeveloperMessageParam, {}),
+        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
+    ],
+)
+def test_message_rejects_images(message_class, kwargs):
+    """Test that system, assistant, developer, and tool messages reject image content."""
+    with pytest.raises(ValidationError):
+        message_class(
+            content=[
+                OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
+            ],
+            **kwargs,
+        )
+
+
+def test_user_message_accepts_images():
+    """Test that user messages accept image content (unlike other message types)."""
+    # List with images should work
+    msg = OpenAIUserMessageParam(
+        content=[
+            OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
+            OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
+        ]
+    )
+    assert len(msg.content) == 2
+    assert msg.content[0].text == "Describe this image:"
+    assert msg.content[1].image_url.url == "http://example.com/image.jpg"