From 65ca85ba6b938bf14a848200ebbf0ad111c837f4 Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Wed, 19 Mar 2025 10:36:19 -0700
Subject: [PATCH] fix: Updating `ToolCall.arguments` to allow for json strings that can be decoded on client side (#1685)

### What does this PR do?
Currently, `ToolCall.arguments` is a `Dict[str, RecursiveType]`. However, on
the client SDK side the `RecursiveType` gets deserialized into a single number
type (`int` and `float` are collapsed), so parameters sent as `int` come back
as `float`. This can break client-side tools that do type checking.

Closes: https://github.com/meta-llama/llama-stack/issues/1683

### Test Plan
Stainless changes -- https://github.com/meta-llama/llama-stack-client-python/pull/204
```
pytest -s -v --stack-config=fireworks tests/integration/agents/test_agents.py --text-model meta-llama/Llama-3.1-8B-Instruct
```
---
 docs/_static/llama-stack-spec.html            | 132 ++++++++++--------
 docs/_static/llama-stack-spec.yaml            |  52 +++----
 llama_stack/models/llama/datatypes.py         |   9 +-
 .../models/llama/llama3/chat_format.py        |   9 +-
 .../models/llama/llama3/template_data.py      |   7 +-
 .../providers/inline/inference/vllm/vllm.py   |   1 +
 .../remote/inference/sambanova/sambanova.py   |  10 +-
 .../providers/remote/inference/vllm/vllm.py   |   8 +-
 .../utils/inference/openai_compat.py          |  14 +-
 tests/unit/models/test_prompt_adapter.py      |   5 +-
 10 files changed, 137 insertions(+), 110 deletions(-)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index b32b7cfdf..eb626fc44 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4159,70 +4159,80 @@
           ]
         },
         "arguments": {
-          "type": "object",
-          "additionalProperties": {
-            "oneOf": [
-              {
-                "type": "string"
-              },
-              {
-                "type": "integer"
-              },
-              {
-                "type": "number"
-              },
-              {
-                "type": "boolean"
-              },
-              {
-                "type": "null"
-              },
-              {
-                "type": "array",
-                "items": {
-                  "oneOf": [
-                    {
-                      "type": "string"
-                    },
-                    {
-                      "type": "integer"
-                    },
-                    {
-                      "type": "number"
-                    },
-                    {
-                      "type": "boolean"
-                    },
-                    {
-                      "type": "null"
-                    }
-                  ]
-                }
-              },
-              {
-                "type": "object",
-                "additionalProperties": {
-                  "oneOf": [
-                    {
-                      "type": "string"
-                    },
-                    {
-                      "type": "integer"
-                    },
-                    {
-                      "type": "number"
-                    },
-                    {
-                      "type": "boolean"
-                    },
-                    {
-                      "type": "null"
-                    }
-                  ]
-                }
-              }
-            ]
-          }
-        }
+          "oneOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "object",
+              "additionalProperties": {
+                "oneOf": [
+                  {
+                    "type": "string"
+                  },
+                  {
+                    "type": "integer"
+                  },
+                  {
+                    "type": "number"
+                  },
+                  {
+                    "type": "boolean"
+                  },
+                  {
+                    "type": "null"
+                  },
+                  {
+                    "type": "array",
+                    "items": {
+                      "oneOf": [
+                        {
+                          "type": "string"
+                        },
+                        {
+                          "type": "integer"
+                        },
+                        {
+                          "type": "number"
+                        },
+                        {
+                          "type": "boolean"
+                        },
+                        {
+                          "type": "null"
+                        }
+                      ]
+                    }
+                  },
+                  {
+                    "type": "object",
+                    "additionalProperties": {
+                      "oneOf": [
+                        {
+                          "type": "string"
+                        },
+                        {
+                          "type": "integer"
+                        },
+                        {
+                          "type": "number"
+                        },
+                        {
+                          "type": "boolean"
+                        },
+                        {
+                          "type": "null"
+                        }
+                      ]
+                    }
+                  }
+                ]
+              }
+            }
+          ]
+        },
+        "arguments_json": {
+          "type": "string"
+        }
       },
       "additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index eb5d9722e..fa6920381 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2864,30 +2864,34 @@ components:
             title: BuiltinTool
           - type: string
         arguments:
-          type: object
-          additionalProperties:
-            oneOf:
-              - type: string
-              - type: integer
-              - type: number
-              - type: boolean
-              - type: 'null'
-              - type: array
-                items:
-                  oneOf:
-                    - type: string
-                    - type: integer
-                    - type: number
-                    - type: boolean
-                    - type: 'null'
-              - type: object
-                additionalProperties:
-                  oneOf:
-                    - type: string
-                    - type: integer
-                    - type: number
-                    - type: boolean
-                    - type: 'null'
+          oneOf:
+            - type: string
+            - type: object
+              additionalProperties:
+                oneOf:
+                  - type: string
+                  - type: integer
+                  - type: number
+                  - type: boolean
+                  - type: 'null'
+                  - type: array
+                    items:
+                      oneOf:
+                        - type: string
+                        - type: integer
+                        - type: number
+                        - type: boolean
+                        - type: 'null'
+                  - type: object
+                    additionalProperties:
+                      oneOf:
+                        - type: string
+                        - type: integer
+                        - type: number
+                        - type: boolean
+                        - type: 'null'
+        arguments_json:
+          type: string
       additionalProperties: false
       required:
         - call_id
diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py
index b25bf0ea9..9842d7980 100644
--- a/llama_stack/models/llama/datatypes.py
+++ b/llama_stack/models/llama/datatypes.py
@@ -47,7 +47,14 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
 class ToolCall(BaseModel):
     call_id: str
     tool_name: Union[BuiltinTool, str]
-    arguments: Dict[str, RecursiveType]
+    # Plan is to deprecate the Dict in favor of a JSON string
+    # that is parsed on the client side instead of trying to manage
+    # the recursive type here.
+    # Making this a union so that client side can start prepping for this change.
+    # Eventually, we will remove both the Dict and arguments_json field,
+    # and arguments will just be a str
+    arguments: Union[str, Dict[str, RecursiveType]]
+    arguments_json: Optional[str] = None
 
     @field_validator("tool_name", mode="before")
     @classmethod
diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py
index 011ccb02a..2862f8558 100644
--- a/llama_stack/models/llama/llama3/chat_format.py
+++ b/llama_stack/models/llama/llama3/chat_format.py
@@ -12,6 +12,7 @@
 # the top-level of this source tree.
 
 import io
+import json
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
@@ -203,9 +204,10 @@ class ChatFormat:
             # This code tries to handle that case
             if tool_name in BuiltinTool.__members__:
                 tool_name = BuiltinTool[tool_name]
-                tool_arguments = {
-                    "query": list(tool_arguments.values())[0],
-                }
+                if isinstance(tool_arguments, dict):
+                    tool_arguments = {
+                        "query": list(tool_arguments.values())[0],
+                    }
         else:
             builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
             if builtin_tool_info is not None:
@@ -229,6 +231,7 @@ class ChatFormat:
                     call_id=call_id,
                     tool_name=tool_name,
                     arguments=tool_arguments,
+                    arguments_json=json.dumps(tool_arguments),
                 )
             )
             content = ""
diff --git a/llama_stack/models/llama/llama3/template_data.py b/llama_stack/models/llama/llama3/template_data.py
index aa16aa009..076b4adb4 100644
--- a/llama_stack/models/llama/llama3/template_data.py
+++ b/llama_stack/models/llama/llama3/template_data.py
@@ -11,11 +11,8 @@
 # top-level folder for each specific model found within the models/ directory at
 # the top-level of this source tree.
 
-from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
-    StopReason,
-    ToolCall,
-)
+
+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
 
 from .prompt_templates import (
     BuiltinToolGenerator,
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index b59df13d0..256e0f821 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -582,6 +582,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 tool_name=t.function.name,
                 # vLLM function args come back as a string. Llama Stack expects JSON.
                 arguments=json.loads(t.function.arguments),
+                arguments_json=t.function.arguments,
             )
             for t in vllm_message.tool_calls
         ],
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index a5e17c2a3..635a42d38 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -42,9 +42,7 @@ from llama_stack.models.llama.datatypes import (
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_stream_response,
 )
@@ -293,14 +291,12 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         if not tool_calls:
             return []
 
-        for call in tool_calls:
-            call_function_arguments = json.loads(call.function.arguments)
-
         compitable_tool_calls = [
             ToolCall(
                 call_id=call.id,
                 tool_name=call.function.name,
-                arguments=call_function_arguments,
+                arguments=json.loads(call.function.arguments),
+                arguments_json=call.function.arguments,
             )
             for call in tool_calls
         ]
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index f940de7ba..eda1a179c 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -90,15 +90,12 @@ def _convert_to_vllm_tool_calls_in_response(
     if not tool_calls:
         return []
 
-    call_function_arguments = None
-    for call in tool_calls:
-        call_function_arguments = json.loads(call.function.arguments)
-
     return [
         ToolCall(
             call_id=call.id,
             tool_name=call.function.name,
-            arguments=call_function_arguments,
+            arguments=json.loads(call.function.arguments),
+            arguments_json=call.function.arguments,
         )
         for call in tool_calls
     ]
@@ -183,6 +180,7 @@ async def _process_vllm_chat_completion_stream_response(
                             call_id=tool_call_buf.call_id,
                             tool_name=tool_call_buf.tool_name,
                             arguments=args,
+                            arguments_json=args_str,
                         ),
                         parse_status=ToolCallParseStatus.succeeded,
                     ),
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 2a362f8cb..b264c7312 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -529,7 +529,11 @@ async def convert_message_to_openai_dict_new(
 ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
     async def impl(
         content_: InterleavedContent,
-    ) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
+    ) -> Union[
+        str,
+        OpenAIChatCompletionContentPartParam,
+        List[OpenAIChatCompletionContentPartParam],
+    ]:
         # Llama Stack and OpenAI spec match for str and text input
         if isinstance(content_, str):
             return content_
@@ -570,7 +574,7 @@ async def convert_message_to_openai_dict_new(
                 OpenAIChatCompletionMessageToolCall(
                     id=tool.call_id,
                     function=OpenAIFunction(
-                        name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
+                        name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
                         arguments=json.dumps(tool.arguments),
                     ),
                     type="function",
@@ -609,6 +613,7 @@ def convert_tool_call(
             call_id=tool_call.id,
             tool_name=tool_call.function.name,
             arguments=json.loads(tool_call.function.arguments),
+            arguments_json=tool_call.function.arguments,
         )
     except Exception:
         return UnparseableToolCall(
@@ -759,6 +764,7 @@ def _convert_openai_tool_calls(
             call_id=call.id,
             tool_name=call.function.name,
             arguments=json.loads(call.function.arguments),
+            arguments_json=call.function.arguments,
         )
         for call in tool_calls
     ]
@@ -890,7 +896,8 @@ async def convert_openai_chat_completion_stream(
                 # ChatCompletionResponseEvent only supports one per stream
                 if len(choice.delta.tool_calls) > 1:
                     warnings.warn(
-                        "multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
+                        "multiple tool calls found in a single delta, using the first, ignoring the rest",
+                        stacklevel=2,
                    )
 
                 if not enable_incremental_tool_calls:
@@ -971,6 +978,7 @@ async def convert_openai_chat_completion_stream(
                         call_id=buffer["call_id"],
                         tool_name=buffer["name"],
                         arguments=arguments,
+                        arguments_json=buffer["arguments"],
                     )
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py
index c3755e2cb..0e2780e50 100644
--- a/tests/unit/models/test_prompt_adapter.py
+++ b/tests/unit/models/test_prompt_adapter.py
@@ -165,7 +165,10 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         request.model = MODEL
         request.tool_config.tool_prompt_format = ToolPromptFormat.json
         prompt = await chat_completion_request_to_prompt(request, request.model)
-        self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
+        self.assertIn(
+            '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
+            prompt,
+        )
 
     async def test_user_provided_system_message(self):
         content = "Hello !"
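
### Appendix: illustrative sketches (not part of the patch)

The failure mode the PR body describes can be reproduced with the standard library alone. This is an editor's illustration, not code from either repo: JSON itself keeps `int` and `float` distinct, but a client that maps the schema's `number` type onto `float` collapses both.

```python
import json

# json.loads preserves the distinction between 1 and 1.0 ...
assert type(json.loads('{"x": 1}')["x"]) is int
assert type(json.loads('{"x": 1.0}')["x"]) is float

# ... but a generated client that models "number" as float collapses both,
# which is the int -> float corruption this PR works around:
collapsed = {k: float(v) for k, v in json.loads('{"x": 1}').items()}
assert type(collapsed["x"]) is float  # int-ness is lost; strict tools break
```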
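On the client side, code can start preferring the raw JSON string now and stay compatible through the planned deprecation. A minimal sketch; `decode_tool_arguments` is a hypothetical helper, not part of this PR or the client SDK, and it assumes the field shapes introduced above (`arguments` as `Union[str, dict]`, optional `arguments_json`).

```python
import json
from typing import Any, Dict, Optional, Union


def decode_tool_arguments(
    arguments: Union[str, Dict[str, Any]],
    arguments_json: Optional[str] = None,
) -> Dict[str, Any]:
    """Hypothetical client-side helper: recover arguments with native JSON types."""
    # Prefer the raw JSON string when the server provides it; parsing it
    # locally keeps ints as ints and floats as floats.
    if arguments_json is not None:
        return json.loads(arguments_json)
    # Once servers send `arguments` as a plain JSON string, parse it directly.
    if isinstance(arguments, str):
        return json.loads(arguments)
    # Legacy path: a dict that may already have suffered int/float collapse.
    return arguments


args = decode_tool_arguments({"n": 1.0}, arguments_json='{"n": 1}')
assert type(args["n"]) is int  # the raw JSON round-trip preserves int
```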
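The provider-side changes in this diff all follow one pattern: parse the server's argument string once for the legacy dict field, and keep the raw string alongside it. Shown in isolation below; `OpenAIToolCallLike` is a stand-in for the various SDK tool-call types, and the sketch assumes this patch is applied so `ToolCall` accepts `arguments_json`.

```python
import json
from dataclasses import dataclass

from llama_stack.models.llama.datatypes import ToolCall


@dataclass
class OpenAIToolCallLike:
    """Stand-in for the SDK tool-call objects converted in this PR."""

    id: str
    name: str
    arguments: str  # OpenAI-compatible servers return arguments as a JSON string


def convert(call: OpenAIToolCallLike) -> ToolCall:
    # Same shape as the vLLM / SambaNova / openai_compat changes above:
    # decode once for the deprecated dict field, keep the raw string too.
    return ToolCall(
        call_id=call.id,
        tool_name=call.name,
        arguments=json.loads(call.arguments),
        arguments_json=call.arguments,
    )
```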
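The streaming changes follow the same idea: argument deltas arrive as string fragments, are accumulated in a buffer, and are parsed only once at end-of-turn so both representations can be emitted together. A stdlib-only sketch of that flow; the buffer keys mirror the `openai_compat.py` hunk above, and the delta chunks are made up for illustration.

```python
import json

# Buffer shape mirrors the streaming code above (call_id / name / arguments).
buffer = {"call_id": "call_1", "name": "get_weather", "arguments": ""}

# Argument fragments stream in as arbitrary string chunks (made-up here).
for delta in ['{"ci', 'ty": "Tokyo', '", "days": 2}']:
    buffer["arguments"] += delta

arguments = json.loads(buffer["arguments"])  # dict for the legacy field
arguments_json = buffer["arguments"]         # raw string for arguments_json
assert arguments == {"city": "Tokyo", "days": 2}
```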