mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
refactor(responses): move stuff into some utils and add unit tests (#3158)
# What does this PR do? Refactors the OpenAI response conversion utilities by moving helper functions from `openai_responses.py` to `utils.py`. Adds unit tests.
This commit is contained in:
parent
47d5af703c
commit
9324e902f1
3 changed files with 435 additions and 120 deletions
|
@ -19,9 +19,6 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
MCPListToolsTool,
|
||||
OpenAIDeleteResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContent,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
|
@ -29,9 +26,6 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
|
@ -39,23 +33,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAIJSONSchema,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAIResponseFormatText,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools import Tool, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
|
@ -69,106 +47,14 @@ from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
|||
from .streaming import StreamingResponseOrchestrator
|
||||
from .tool_executor import ToolExecutor
|
||||
from .types import ChatCompletionContext
|
||||
from .utils import (
|
||||
convert_response_input_to_chat_messages,
|
||||
convert_response_text_to_chat_response_format,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
|
||||
|
||||
async def _convert_response_content_to_chat_content(
|
||||
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
"""
|
||||
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
||||
|
||||
The content schemas of each API look similar, but are not exactly the same.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
converted_parts = []
|
||||
for content_part in content:
|
||||
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
|
||||
if content_part.image_url:
|
||||
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
|
||||
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||
elif isinstance(content_part, str):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
|
||||
)
|
||||
return converted_parts
|
||||
|
||||
|
||||
async def _convert_response_input_to_chat_messages(
|
||||
input: str | list[OpenAIResponseInput],
|
||||
) -> list[OpenAIMessageParam]:
|
||||
"""
|
||||
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
|
||||
"""
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if isinstance(input, list):
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.call_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id=input_item.call_id,
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=input_item.name,
|
||||
arguments=input_item.arguments,
|
||||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
else:
|
||||
content = await _convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await _get_message_type_by_role(input_item.role)
|
||||
if message_type is None:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
messages.append(message_type(content=content))
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
||||
|
||||
async def _convert_response_text_to_chat_response_format(
|
||||
text: OpenAIResponseText,
|
||||
) -> OpenAIResponseFormatParam:
|
||||
"""
|
||||
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
|
||||
"""
|
||||
if not text.format or text.format["type"] == "text":
|
||||
return OpenAIResponseFormatText(type="text")
|
||||
if text.format["type"] == "json_object":
|
||||
return OpenAIResponseFormatJSONObject()
|
||||
if text.format["type"] == "json_schema":
|
||||
return OpenAIResponseFormatJSONSchema(
|
||||
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
|
||||
)
|
||||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
||||
async def _get_message_type_by_role(role: str):
|
||||
role_to_type = {
|
||||
"user": OpenAIUserMessageParam,
|
||||
"system": OpenAISystemMessageParam,
|
||||
"assistant": OpenAIAssistantMessageParam,
|
||||
"developer": OpenAIDeveloperMessageParam,
|
||||
}
|
||||
return role_to_type.get(role)
|
||||
|
||||
|
||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||
input_items: ListOpenAIResponseInputItem
|
||||
response: OpenAIResponseObject
|
||||
|
@ -350,11 +236,11 @@ class OpenAIResponsesImpl:
|
|||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Input preprocessing
|
||||
input = await self._prepend_previous_response(input, previous_response_id)
|
||||
messages = await _convert_response_input_to_chat_messages(input)
|
||||
messages = await convert_response_input_to_chat_messages(input)
|
||||
await self._prepend_instructions(messages, instructions)
|
||||
|
||||
# Structured outputs
|
||||
response_format = await _convert_response_text_to_chat_response_format(text)
|
||||
response_format = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
# Tool setup, TODO: refactor this slightly since this can also yield events
|
||||
chat_tools, mcp_tool_to_server, mcp_list_message = (
|
||||
|
|
|
@ -7,14 +7,37 @@
|
|||
import uuid
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContent,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseText,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAIJSONSchema,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAIResponseFormatText,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
|
||||
|
||||
|
@ -38,6 +61,102 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenA
|
|||
)
|
||||
|
||||
|
||||
async def convert_response_content_to_chat_content(
|
||||
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
"""
|
||||
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
||||
|
||||
The content schemas of each API look similar, but are not exactly the same.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
converted_parts = []
|
||||
for content_part in content:
|
||||
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
|
||||
if content_part.image_url:
|
||||
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
|
||||
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||
elif isinstance(content_part, str):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
|
||||
)
|
||||
return converted_parts
|
||||
|
||||
|
||||
async def convert_response_input_to_chat_messages(
|
||||
input: str | list[OpenAIResponseInput],
|
||||
) -> list[OpenAIMessageParam]:
|
||||
"""
|
||||
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
|
||||
"""
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if isinstance(input, list):
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.call_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id=input_item.call_id,
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=input_item.name,
|
||||
arguments=input_item.arguments,
|
||||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
else:
|
||||
content = await convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await get_message_type_by_role(input_item.role)
|
||||
if message_type is None:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
messages.append(message_type(content=content))
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
||||
|
||||
async def convert_response_text_to_chat_response_format(
|
||||
text: OpenAIResponseText,
|
||||
) -> OpenAIResponseFormatParam:
|
||||
"""
|
||||
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
|
||||
"""
|
||||
if not text.format or text.format["type"] == "text":
|
||||
return OpenAIResponseFormatText(type="text")
|
||||
if text.format["type"] == "json_object":
|
||||
return OpenAIResponseFormatJSONObject()
|
||||
if text.format["type"] == "json_schema":
|
||||
return OpenAIResponseFormatJSONSchema(
|
||||
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
|
||||
)
|
||||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
||||
async def get_message_type_by_role(role: str):
|
||||
role_to_type = {
|
||||
"user": OpenAIUserMessageParam,
|
||||
"system": OpenAISystemMessageParam,
|
||||
"assistant": OpenAIAssistantMessageParam,
|
||||
"developer": OpenAIDeveloperMessageParam,
|
||||
}
|
||||
return role_to_type.get(role)
|
||||
|
||||
|
||||
def is_function_tool_call(
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
tools: list[OpenAIResponseInputTool],
|
||||
|
|
|
@ -0,0 +1,310 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatText,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
convert_chat_choice_to_response_message,
|
||||
convert_response_content_to_chat_content,
|
||||
convert_response_input_to_chat_messages,
|
||||
convert_response_text_to_chat_response_format,
|
||||
get_message_type_by_role,
|
||||
is_function_tool_call,
|
||||
)
|
||||
|
||||
|
||||
class TestConvertChatChoiceToResponseMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_string_content(self):
|
||||
choice = OpenAIChoice(
|
||||
message=OpenAIAssistantMessageParam(content="Test message"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
|
||||
result = await convert_chat_choice_to_response_message(choice)
|
||||
|
||||
assert result.role == "assistant"
|
||||
assert result.status == "completed"
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], OpenAIResponseOutputMessageContentOutputText)
|
||||
assert result.content[0].text == "Test message"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_text_param_content(self):
|
||||
choice = OpenAIChoice(
|
||||
message=OpenAIAssistantMessageParam(
|
||||
content=[OpenAIChatCompletionContentPartTextParam(text="Test text param")]
|
||||
),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await convert_chat_choice_to_response_message(choice)
|
||||
|
||||
assert "does not yet support output content type" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestConvertResponseContentToChatContent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_string_content(self):
|
||||
result = await convert_response_content_to_chat_content("Simple string")
|
||||
assert result == "Simple string"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_text_content_parts(self):
|
||||
content = [
|
||||
OpenAIResponseInputMessageContentText(text="First part"),
|
||||
OpenAIResponseOutputMessageContentOutputText(text="Second part"),
|
||||
]
|
||||
|
||||
result = await convert_response_content_to_chat_content(content)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam)
|
||||
assert result[0].text == "First part"
|
||||
assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
|
||||
assert result[1].text == "Second part"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_image_content(self):
|
||||
content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]
|
||||
|
||||
result = await convert_response_content_to_chat_content(content)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam)
|
||||
assert result[0].image_url.url == "https://example.com/image.jpg"
|
||||
assert result[0].image_url.detail == "high"
|
||||
|
||||
|
||||
class TestConvertResponseInputToChatMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_string_input(self):
|
||||
result = await convert_response_input_to_chat_messages("User message")
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIUserMessageParam)
|
||||
assert result[0].content == "User message"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_function_tool_call_output(self):
|
||||
input_items = [
|
||||
OpenAIResponseInputFunctionToolCallOutput(
|
||||
output="Tool output",
|
||||
call_id="call_123",
|
||||
)
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIToolMessageParam)
|
||||
assert result[0].content == "Tool output"
|
||||
assert result[0].tool_call_id == "call_123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_function_tool_call(self):
|
||||
input_items = [
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_456",
|
||||
name="test_function",
|
||||
arguments='{"param": "value"}',
|
||||
)
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIAssistantMessageParam)
|
||||
assert len(result[0].tool_calls) == 1
|
||||
assert result[0].tool_calls[0].id == "call_456"
|
||||
assert result[0].tool_calls[0].function.name == "test_function"
|
||||
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_response_message(self):
|
||||
input_items = [
|
||||
OpenAIResponseMessage(
|
||||
role="user",
|
||||
content=[OpenAIResponseInputMessageContentText(text="User text")],
|
||||
)
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIUserMessageParam)
|
||||
# Content should be converted to chat content format
|
||||
assert len(result[0].content) == 1
|
||||
assert result[0].content[0].text == "User text"
|
||||
|
||||
|
||||
class TestConvertResponseTextToChatResponseFormat:
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_text_format(self):
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatText)
|
||||
assert result.type == "text"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_json_object_format(self):
|
||||
text = OpenAIResponseText(format={"type": "json_object"})
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatJSONObject)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_json_schema_format(self):
|
||||
schema_def = {"type": "object", "properties": {"test": {"type": "string"}}}
|
||||
text = OpenAIResponseText(
|
||||
format={
|
||||
"type": "json_schema",
|
||||
"name": "test_schema",
|
||||
"schema": schema_def,
|
||||
}
|
||||
)
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatJSONSchema)
|
||||
assert result.json_schema["name"] == "test_schema"
|
||||
assert result.json_schema["schema"] == schema_def
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_text_format(self):
|
||||
text = OpenAIResponseText()
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatText)
|
||||
assert result.type == "text"
|
||||
|
||||
|
||||
class TestGetMessageTypeByRole:
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_role(self):
|
||||
result = await get_message_type_by_role("user")
|
||||
assert result == OpenAIUserMessageParam
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_role(self):
|
||||
result = await get_message_type_by_role("system")
|
||||
assert result == OpenAISystemMessageParam
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assistant_role(self):
|
||||
result = await get_message_type_by_role("assistant")
|
||||
assert result == OpenAIAssistantMessageParam
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_developer_role(self):
|
||||
result = await get_message_type_by_role("developer")
|
||||
assert result == OpenAIDeveloperMessageParam
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_role(self):
|
||||
result = await get_message_type_by_role("unknown")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestIsFunctionToolCall:
|
||||
def test_is_function_tool_call_true(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name="test_function",
|
||||
arguments="{}",
|
||||
),
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(
|
||||
type="function", name="test_function", parameters={"type": "object", "properties": {}}
|
||||
),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is True
|
||||
|
||||
def test_is_function_tool_call_false_different_name(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name="other_function",
|
||||
arguments="{}",
|
||||
),
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(
|
||||
type="function", name="test_function", parameters={"type": "object", "properties": {}}
|
||||
),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is False
|
||||
|
||||
def test_is_function_tool_call_false_no_function(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=None,
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(
|
||||
type="function", name="test_function", parameters={"type": "object", "properties": {}}
|
||||
),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is False
|
||||
|
||||
def test_is_function_tool_call_false_wrong_type(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name="web_search",
|
||||
arguments="{}",
|
||||
),
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is False
|
Loading…
Add table
Add a link
Reference in a new issue