fix: Updating ToolCall.arguments to allow for json strings that can be decoded on client side (#1685)
### What does this PR do?

Currently, `ToolCall.arguments` is a `Dict[str, RecursiveType]`. However, on the client SDK side the `RecursiveType` gets deserialized into a single number type (both int and float are collapsed), so params that are `int` get converted to `float`, which can break client-side tools that do type checking. This change keeps the raw JSON string alongside the parsed arguments so clients can decode it losslessly themselves.

Closes: https://github.com/meta-llama/llama-stack/issues/1683

### Test Plan

Stainless changes -- https://github.com/meta-llama/llama-stack-client-python/pull/204

```
pytest -s -v --stack-config=fireworks tests/integration/agents/test_agents.py --text-model meta-llama/Llama-3.1-8B-Instruct
```
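To make the failure mode concrete, here is a minimal sketch, assuming a client-side decoder that widens every JSON number to float (names and values are illustrative, not from the PR):

```python
import json

args = {"max_results": 5}  # the tool expects an int
wire = json.dumps(args)    # serialized form sent to the client

# If the client SDK deserializes every JSON number into a single
# float-like type, the int is silently widened:
collapsed = {k: float(v) for k, v in json.loads(wire).items()}
assert collapsed["max_results"] == 5.0
assert not isinstance(collapsed["max_results"], int)  # client type check breaks

# Decoding the raw JSON string directly (what the new `arguments_json`
# field enables) preserves the original int:
recovered = json.loads(wire)
assert isinstance(recovered["max_results"], int)
```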
This commit is contained in:
parent 113f3a259c
commit 65ca85ba6b
10 changed files with 137 additions and 110 deletions
```diff
@@ -529,7 +529,11 @@ async def convert_message_to_openai_dict_new(
     ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
         async def impl(
             content_: InterleavedContent,
-        ) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
+        ) -> Union[
+            str,
+            OpenAIChatCompletionContentPartParam,
+            List[OpenAIChatCompletionContentPartParam],
+        ]:
             # Llama Stack and OpenAI spec match for str and text input
             if isinstance(content_, str):
                 return content_
```
```diff
@@ -570,7 +574,7 @@ async def convert_message_to_openai_dict_new(
             OpenAIChatCompletionMessageToolCall(
                 id=tool.call_id,
                 function=OpenAIFunction(
-                    name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
+                    name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
                     arguments=json.dumps(tool.arguments),
                 ),
                 type="function",
```
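For context, the OpenAI chat-completions format already carries tool-call arguments as a JSON-encoded string, which is what `json.dumps(tool.arguments)` above produces. A rough sketch of the resulting payload shape (field values are made up):

```python
# Illustrative OpenAI-style tool call entry; values are hypothetical.
openai_tool_call = {
    "id": "call_abc123",
    "type": "function",
    "function": {
        "name": "get_weather",
        # arguments travel as a JSON-encoded string, not a parsed dict
        "arguments": '{"city": "Paris", "days": 3}',
    },
}
```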
```diff
@@ -609,6 +613,7 @@ def convert_tool_call(
             call_id=tool_call.id,
             tool_name=tool_call.function.name,
             arguments=json.loads(tool_call.function.arguments),
+            arguments_json=tool_call.function.arguments,
         )
     except Exception:
         return UnparseableToolCall(
```
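This hunk and the next both populate the new field. As a rough mental model, inferred from the diff rather than taken from the actual model definition in llama-stack, the updated `ToolCall` carries both the parsed dict and the lossless raw string:

```python
from typing import Any, Optional

from pydantic import BaseModel


class ToolCall(BaseModel):  # assumed shape, sketched from the diff
    call_id: str
    tool_name: str
    arguments: dict[str, Any]             # parsed; convenient server-side
    arguments_json: Optional[str] = None  # raw JSON; lossless for clients
```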
```diff
@@ -759,6 +764,7 @@ def _convert_openai_tool_calls(
             call_id=call.id,
             tool_name=call.function.name,
             arguments=json.loads(call.function.arguments),
+            arguments_json=call.function.arguments,
         )
         for call in tool_calls
     ]
```
```diff
@@ -890,7 +896,8 @@ async def convert_openai_chat_completion_stream(
                 # ChatCompletionResponseEvent only supports one per stream
                 if len(choice.delta.tool_calls) > 1:
                     warnings.warn(
-                        "multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
+                        "multiple tool calls found in a single delta, using the first, ignoring the rest",
+                        stacklevel=2,
                     )

                 if not enable_incremental_tool_calls:
```
```diff
@@ -971,6 +978,7 @@ async def convert_openai_chat_completion_stream(
                 call_id=buffer["call_id"],
                 tool_name=buffer["name"],
                 arguments=arguments,
+                arguments_json=buffer["arguments"],
             )
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
```
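On the client side, a consumer can now prefer the raw string and decode it itself. A usage sketch, assuming the client SDK surfaces `arguments_json` (the helper below is hypothetical):

```python
import json


def tool_arguments(tool_call) -> dict:
    """Decode tool-call arguments without losing int/float distinctions."""
    raw = getattr(tool_call, "arguments_json", None)
    if raw:
        return json.loads(raw)  # json.loads keeps ints as ints
    return tool_call.arguments  # fallback for servers without the new field
```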