Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 06:00:48 +00:00)
fix: use ChatCompletionMessageFunctionToolCall (#3142)
The OpenAI compatibility layer was incorrectly importing ChatCompletionMessageToolCallParam instead of the ChatCompletionMessageFunctionToolCall class. This caused "Cannot instantiate typing.Union" errors when processing agent requests with tool calls.

Closes: #3141

Signed-off-by: Derek Higgins <derekh@redhat.com>
parent ee7631b6cf
commit c15cc7ed77

2 changed files with 45 additions and 5 deletions
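For context, here is a minimal, self-contained sketch of the failure mode. The classes below (FunctionToolCall, CustomToolCall, AnyToolCall) are hypothetical stand-ins, not the actual OpenAI SDK types; the point is only that a Union-typed alias names a set of types and cannot be called like a class, whereas a concrete member class can.

from dataclasses import dataclass
from typing import Union


@dataclass
class FunctionToolCall:
    # Stand-in for a concrete class such as ChatCompletionMessageFunctionToolCall
    id: str
    name: str


@dataclass
class CustomToolCall:
    # Stand-in for an alternative tool-call variant
    id: str
    payload: str


# Stand-in for a Union-typed alias: it describes a type, not a constructible class.
AnyToolCall = Union[FunctionToolCall, CustomToolCall]

try:
    AnyToolCall(id="call_123", name="get_weather")
except TypeError as exc:
    # Raises TypeError ("Cannot instantiate typing.Union"), the error reported in #3141
    print(exc)

# Instantiating the concrete member class works as expected.
print(FunctionToolCall(id="call_123", name="get_weather"))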
@@ -31,15 +31,15 @@ from openai.types.chat import (
 from openai.types.chat import (
     ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
 )
+from openai.types.chat import (
+    ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall,
+)
 from openai.types.chat import (
     ChatCompletionMessageParam as OpenAIChatCompletionMessage,
 )
 from openai.types.chat import (
     ChatCompletionMessageToolCall,
 )
-from openai.types.chat import (
-    ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
-)
 from openai.types.chat import (
     ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
 )
@@ -633,7 +633,7 @@ async def convert_message_to_openai_dict_new(
         )
     elif isinstance(message, CompletionMessage):
         tool_calls = [
-            OpenAIChatCompletionMessageToolCall(
+            OpenAIChatCompletionMessageFunctionToolCall(
                 id=tool.call_id,
                 function=OpenAIFunction(
                     name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
@@ -903,7 +903,7 @@ def _convert_openai_request_response_format(
 
 
 def _convert_openai_tool_calls(
-    tool_calls: list[OpenAIChatCompletionMessageToolCall],
+    tool_calls: list[OpenAIChatCompletionMessageFunctionToolCall],
 ) -> list[ToolCall]:
     """
     Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
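The hunk above only tightens the parameter annotation of _convert_openai_tool_calls. As orientation, here is a rough sketch of what a converter with this shape does; it is not the repository's implementation, and the output is shown as plain kwargs dicts using the ToolCall field names that appear in the tests below (call_id, tool_name, arguments_json).

def _convert_openai_tool_calls_sketch(tool_calls) -> list[dict]:
    # Hypothetical illustration: map each OpenAI function tool call onto the
    # keyword arguments a llama-stack ToolCall would be built from.
    return [
        {
            "call_id": call.id,                         # OpenAI tool-call id
            "tool_name": call.function.name,            # called function name
            "arguments_json": call.function.arguments,  # raw JSON string of arguments
        }
        for call in tool_calls
    ]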
@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
 from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
+    convert_message_to_openai_dict_new,
     openai_messages_to_messages,
 )
 
@@ -182,3 +183,42 @@ def test_user_message_accepts_images():
     assert len(msg.content) == 2
     assert msg.content[0].text == "Describe this image:"
     assert msg.content[1].image_url.url == "http://example.com/image.jpg"
+
+
+async def test_convert_message_to_openai_dict_new_user_message():
+    """Test convert_message_to_openai_dict_new with UserMessage."""
+    message = UserMessage(content="Hello, world!", role="user")
+    result = await convert_message_to_openai_dict_new(message)
+
+    assert result["role"] == "user"
+    assert result["content"] == "Hello, world!"
+
+
+async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
+    """Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
+    message = CompletionMessage(
+        content="I'll help you find the weather.",
+        tool_calls=[
+            ToolCall(
+                call_id="call_123",
+                tool_name="get_weather",
+                arguments={"city": "Sligo"},
+                arguments_json='{"city": "Sligo"}',
+            )
+        ],
+        stop_reason=StopReason.end_of_turn,
+    )
+    result = await convert_message_to_openai_dict_new(message)
+
+    # This would have failed with "Cannot instantiate typing.Union" before the fix
+    assert result["role"] == "assistant"
+    assert result["content"] == "I'll help you find the weather."
+    assert "tool_calls" in result
+    assert result["tool_calls"] is not None
+    assert len(result["tool_calls"]) == 1
+
+    tool_call = result["tool_calls"][0]
+    assert tool_call.id == "call_123"
+    assert tool_call.type == "function"
+    assert tool_call.function.name == "get_weather"
+    assert tool_call.function.arguments == '{"city": "Sligo"}'