fix: Update ToolCall.arguments to allow JSON strings that can be decoded on the client side (#1685)

### What does this PR do?

Currently, `ToolCall.arguments` is a `Dict[str, RecursiveType]`.
However, on the client SDK side the `RecursiveType` gets deserialized
into a single number type (int and float are collapsed), so parameters
that are `int` end up as `float`, which can break client-side tools
that do type checking.

Closes: https://github.com/meta-llama/llama-stack/issues/1683

### Test Plan
Stainless changes --
https://github.com/meta-llama/llama-stack-client-python/pull/204
```
pytest -s -v --stack-config=fireworks tests/integration/agents/test_agents.py  --text-model meta-llama/Llama-3.1-8B-Instruct
```
Hardik Shah authored on 2025-03-19 10:36:19 -07:00, committed by GitHub
parent 113f3a259c · commit 65ca85ba6b
10 changed files with 137 additions and 110 deletions

```diff
@@ -4159,70 +4159,80 @@
             ]
           },
           "arguments": {
-            "type": "object",
-            "additionalProperties": {
-              "oneOf": [
-                {
-                  "type": "string"
-                },
-                {
-                  "type": "integer"
-                },
-                {
-                  "type": "number"
-                },
-                {
-                  "type": "boolean"
-                },
-                {
-                  "type": "null"
-                },
-                {
-                  "type": "array",
-                  "items": {
-                    "oneOf": [
-                      {
-                        "type": "string"
-                      },
-                      {
-                        "type": "integer"
-                      },
-                      {
-                        "type": "number"
-                      },
-                      {
-                        "type": "boolean"
-                      },
-                      {
-                        "type": "null"
-                      }
-                    ]
-                  }
-                },
-                {
-                  "type": "object",
-                  "additionalProperties": {
-                    "oneOf": [
-                      {
-                        "type": "string"
-                      },
-                      {
-                        "type": "integer"
-                      },
-                      {
-                        "type": "number"
-                      },
-                      {
-                        "type": "boolean"
-                      },
-                      {
-                        "type": "null"
-                      }
-                    ]
-                  }
-                }
-              ]
-            }
+            "oneOf": [
+              {
+                "type": "string"
+              },
+              {
+                "type": "object",
+                "additionalProperties": {
+                  "oneOf": [
+                    {
+                      "type": "string"
+                    },
+                    {
+                      "type": "integer"
+                    },
+                    {
+                      "type": "number"
+                    },
+                    {
+                      "type": "boolean"
+                    },
+                    {
+                      "type": "null"
+                    },
+                    {
+                      "type": "array",
+                      "items": {
+                        "oneOf": [
+                          {
+                            "type": "string"
+                          },
+                          {
+                            "type": "integer"
+                          },
+                          {
+                            "type": "number"
+                          },
+                          {
+                            "type": "boolean"
+                          },
+                          {
+                            "type": "null"
+                          }
+                        ]
+                      }
+                    },
+                    {
+                      "type": "object",
+                      "additionalProperties": {
+                        "oneOf": [
+                          {
+                            "type": "string"
+                          },
+                          {
+                            "type": "integer"
+                          },
+                          {
+                            "type": "number"
+                          },
+                          {
+                            "type": "boolean"
+                          },
+                          {
+                            "type": "null"
+                          }
+                        ]
+                      }
+                    }
+                  ]
+                }
+              }
+            ]
+          },
+          "arguments_json": {
+            "type": "string"
           }
         },
         "additionalProperties": false,
```

```diff
@@ -2864,30 +2864,34 @@ components:
               title: BuiltinTool
             - type: string
         arguments:
-          type: object
-          additionalProperties:
-            oneOf:
-              - type: string
-              - type: integer
-              - type: number
-              - type: boolean
-              - type: 'null'
-              - type: array
-                items:
-                  oneOf:
-                    - type: string
-                    - type: integer
-                    - type: number
-                    - type: boolean
-                    - type: 'null'
-              - type: object
-                additionalProperties:
-                  oneOf:
-                    - type: string
-                    - type: integer
-                    - type: number
-                    - type: boolean
-                    - type: 'null'
+          oneOf:
+            - type: string
+            - type: object
+              additionalProperties:
+                oneOf:
+                  - type: string
+                  - type: integer
+                  - type: number
+                  - type: boolean
+                  - type: 'null'
+                  - type: array
+                    items:
+                      oneOf:
+                        - type: string
+                        - type: integer
+                        - type: number
+                        - type: boolean
+                        - type: 'null'
+                  - type: object
+                    additionalProperties:
+                      oneOf:
+                        - type: string
+                        - type: integer
+                        - type: number
+                        - type: boolean
+                        - type: 'null'
+        arguments_json:
+          type: string
       additionalProperties: false
       required:
         - call_id
```

```diff
@@ -47,7 +47,14 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
 class ToolCall(BaseModel):
     call_id: str
     tool_name: Union[BuiltinTool, str]
-    arguments: Dict[str, RecursiveType]
+    # Plan is to deprecate the Dict in favor of a JSON string
+    # that is parsed on the client side instead of trying to manage
+    # the recursive type here.
+    # Making this a union so that client side can start prepping for this change.
+    # Eventually, we will remove both the Dict and arguments_json field,
+    # and arguments will just be a str
+    arguments: Union[str, Dict[str, RecursiveType]]
+    arguments_json: Optional[str] = None
 
     @field_validator("tool_name", mode="before")
     @classmethod
```
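
A minimal sketch of how client-side code could consume the union during the transition, preferring the raw JSON string when it is present (`parse_arguments` is a hypothetical helper, not part of this PR):

```python
import json
from typing import Any, Dict


def parse_arguments(tool_call) -> Dict[str, Any]:
    # Prefer the raw JSON string so int/float types survive deserialization.
    if getattr(tool_call, "arguments_json", None):
        return json.loads(tool_call.arguments_json)
    # Under the new union type, `arguments` may already be a JSON string.
    if isinstance(tool_call.arguments, str):
        return json.loads(tool_call.arguments)
    # Fall back to the legacy dict form.
    return tool_call.arguments
```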

```diff
@@ -12,6 +12,7 @@
 # the top-level of this source tree.
 
 import io
+import json
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
@@ -203,9 +204,10 @@ class ChatFormat:
             # This code tries to handle that case
             if tool_name in BuiltinTool.__members__:
                 tool_name = BuiltinTool[tool_name]
-                tool_arguments = {
-                    "query": list(tool_arguments.values())[0],
-                }
+                if isinstance(tool_arguments, dict):
+                    tool_arguments = {
+                        "query": list(tool_arguments.values())[0],
+                    }
         else:
             builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
             if builtin_tool_info is not None:
@@ -229,6 +231,7 @@ class ChatFormat:
                     call_id=call_id,
                     tool_name=tool_name,
                     arguments=tool_arguments,
+                    arguments_json=json.dumps(tool_arguments),
                 )
             )
         content = ""
```

```diff
@@ -11,11 +11,8 @@
 # top-level folder for each specific model found within the models/ directory at
 # the top-level of this source tree.
 
-from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
-    StopReason,
-    ToolCall,
-)
+
+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
 
 from .prompt_templates import (
     BuiltinToolGenerator,
```

```diff
@@ -582,6 +582,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                     tool_name=t.function.name,
                     # vLLM function args come back as a string. Llama Stack expects JSON.
                     arguments=json.loads(t.function.arguments),
+                    arguments_json=t.function.arguments,
                 )
                 for t in vllm_message.tool_calls
             ],
```

```diff
@@ -42,9 +42,7 @@ from llama_stack.models.llama.datatypes import (
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_stream_response,
 )
@@ -293,14 +291,12 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         if not tool_calls:
             return []
 
-        for call in tool_calls:
-            call_function_arguments = json.loads(call.function.arguments)
-
         compitable_tool_calls = [
             ToolCall(
                 call_id=call.id,
                 tool_name=call.function.name,
-                arguments=call_function_arguments,
+                arguments=json.loads(call.function.arguments),
+                arguments_json=call.function.arguments,
             )
             for call in tool_calls
         ]
```

```diff
@@ -90,15 +90,12 @@ def _convert_to_vllm_tool_calls_in_response(
     if not tool_calls:
         return []
 
-    call_function_arguments = None
-    for call in tool_calls:
-        call_function_arguments = json.loads(call.function.arguments)
-
     return [
         ToolCall(
             call_id=call.id,
             tool_name=call.function.name,
-            arguments=call_function_arguments,
+            arguments=json.loads(call.function.arguments),
+            arguments_json=call.function.arguments,
         )
         for call in tool_calls
     ]
@@ -183,6 +180,7 @@ async def _process_vllm_chat_completion_stream_response(
                         call_id=tool_call_buf.call_id,
                         tool_name=tool_call_buf.tool_name,
                         arguments=args,
+                        arguments_json=args_str,
                     ),
                     parse_status=ToolCallParseStatus.succeeded,
                 ),
```

```diff
@@ -529,7 +529,11 @@ async def convert_message_to_openai_dict_new(
 ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
     async def impl(
         content_: InterleavedContent,
-    ) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
+    ) -> Union[
+        str,
+        OpenAIChatCompletionContentPartParam,
+        List[OpenAIChatCompletionContentPartParam],
+    ]:
         # Llama Stack and OpenAI spec match for str and text input
         if isinstance(content_, str):
             return content_
@@ -570,7 +574,7 @@ async def convert_message_to_openai_dict_new(
             OpenAIChatCompletionMessageToolCall(
                 id=tool.call_id,
                 function=OpenAIFunction(
-                    name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
+                    name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
                     arguments=json.dumps(tool.arguments),
                 ),
                 type="function",
@@ -609,6 +613,7 @@ def convert_tool_call(
             call_id=tool_call.id,
             tool_name=tool_call.function.name,
             arguments=json.loads(tool_call.function.arguments),
+            arguments_json=tool_call.function.arguments,
         )
     except Exception:
         return UnparseableToolCall(
@@ -759,6 +764,7 @@ def _convert_openai_tool_calls(
             call_id=call.id,
             tool_name=call.function.name,
             arguments=json.loads(call.function.arguments),
+            arguments_json=call.function.arguments,
         )
         for call in tool_calls
     ]
@@ -890,7 +896,8 @@ async def convert_openai_chat_completion_stream(
             # ChatCompletionResponseEvent only supports one per stream
             if len(choice.delta.tool_calls) > 1:
                 warnings.warn(
-                    "multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
+                    "multiple tool calls found in a single delta, using the first, ignoring the rest",
+                    stacklevel=2,
                 )
 
             if not enable_incremental_tool_calls:
@@ -971,6 +978,7 @@ async def convert_openai_chat_completion_stream(
                     call_id=buffer["call_id"],
                     tool_name=buffer["name"],
                     arguments=arguments,
+                    arguments_json=buffer["arguments"],
                 )
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
```

```diff
@@ -165,7 +165,10 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         request.model = MODEL
         request.tool_config.tool_prompt_format = ToolPromptFormat.json
         prompt = await chat_completion_request_to_prompt(request, request.model)
-        self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
+        self.assertIn(
+            '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
+            prompt,
+        )
 
     async def test_user_provided_system_message(self):
         content = "Hello !"
```