new param arguments_json in ToolCall

Hardik Shah 2025-03-18 13:44:29 -07:00
parent 37f155e41d
commit 549096b264
10 changed files with 132 additions and 107 deletions

View file

@@ -4229,70 +4229,80 @@
             ]
           },
           "arguments": {
-            "type": "object",
-            "additionalProperties": {
-              "oneOf": [
-                {
-                  "type": "string"
-                },
-                {
-                  "type": "integer"
-                },
-                {
-                  "type": "number"
-                },
-                {
-                  "type": "boolean"
-                },
-                {
-                  "type": "null"
-                },
-                {
-                  "type": "array",
-                  "items": {
-                    "oneOf": [
-                      {
-                        "type": "string"
-                      },
-                      {
-                        "type": "integer"
-                      },
-                      {
-                        "type": "number"
-                      },
-                      {
-                        "type": "boolean"
-                      },
-                      {
-                        "type": "null"
-                      }
-                    ]
-                  }
-                },
-                {
-                  "type": "object",
-                  "additionalProperties": {
-                    "oneOf": [
-                      {
-                        "type": "string"
-                      },
-                      {
-                        "type": "integer"
-                      },
-                      {
-                        "type": "number"
-                      },
-                      {
-                        "type": "boolean"
-                      },
-                      {
-                        "type": "null"
-                      }
-                    ]
-                  }
-                }
-              ]
-            }
+            "oneOf": [
+              {
+                "type": "string"
+              },
+              {
+                "type": "object",
+                "additionalProperties": {
+                  "oneOf": [
+                    {
+                      "type": "string"
+                    },
+                    {
+                      "type": "integer"
+                    },
+                    {
+                      "type": "number"
+                    },
+                    {
+                      "type": "boolean"
+                    },
+                    {
+                      "type": "null"
+                    },
+                    {
+                      "type": "array",
+                      "items": {
+                        "oneOf": [
+                          {
+                            "type": "string"
+                          },
+                          {
+                            "type": "integer"
+                          },
+                          {
+                            "type": "number"
+                          },
+                          {
+                            "type": "boolean"
+                          },
+                          {
+                            "type": "null"
+                          }
+                        ]
+                      }
+                    },
+                    {
+                      "type": "object",
+                      "additionalProperties": {
+                        "oneOf": [
+                          {
+                            "type": "string"
+                          },
+                          {
+                            "type": "integer"
+                          },
+                          {
+                            "type": "number"
+                          },
+                          {
+                            "type": "boolean"
+                          },
+                          {
+                            "type": "null"
+                          }
+                        ]
+                      }
+                    }
+                  ]
+                }
+              }
+            ]
+          },
+          "arguments_json": {
+            "type": "string"
           }
         },
         "additionalProperties": false,

View file

@@ -2884,30 +2884,34 @@ components:
             title: BuiltinTool
           - type: string
         arguments:
-          type: object
-          additionalProperties:
-            oneOf:
-            - type: string
-            - type: integer
-            - type: number
-            - type: boolean
-            - type: 'null'
-            - type: array
-              items:
-                oneOf:
-                - type: string
-                - type: integer
-                - type: number
-                - type: boolean
-                - type: 'null'
-            - type: object
-              additionalProperties:
-                oneOf:
-                - type: string
-                - type: integer
-                - type: number
-                - type: boolean
-                - type: 'null'
+          oneOf:
+          - type: string
+          - type: object
+            additionalProperties:
+              oneOf:
+              - type: string
+              - type: integer
+              - type: number
+              - type: boolean
+              - type: 'null'
+              - type: array
+                items:
+                  oneOf:
+                  - type: string
+                  - type: integer
+                  - type: number
+                  - type: boolean
+                  - type: 'null'
+              - type: object
+                additionalProperties:
+                  oneOf:
+                  - type: string
+                  - type: integer
+                  - type: number
+                  - type: boolean
+                  - type: 'null'
+        arguments_json:
+          type: string
       additionalProperties: false
       required:
       - call_id

View file

@@ -47,7 +47,9 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
 class ToolCall(BaseModel):
     call_id: str
     tool_name: Union[BuiltinTool, str]
-    arguments: Dict[str, RecursiveType]
+    arguments: Union[str, Dict[str, RecursiveType]]
+    # Temporary field for backwards compatibility
+    arguments_json: Optional[str] = None

     @field_validator("tool_name", mode="before")
     @classmethod
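
A minimal sketch of the widened model in use; the import path and field names come from the diff, the values are made up:

    import json

    from llama_stack.models.llama.datatypes import ToolCall

    args = {"query": "latest llama release"}
    call = ToolCall(
        call_id="call-0",
        tool_name="brave_search",
        arguments=args,                   # dict form, as before
        arguments_json=json.dumps(args),  # new, optional raw-JSON mirror
    )

    # The widened union also accepts the raw string directly:
    call_raw = ToolCall(
        call_id="call-1",
        tool_name="brave_search",
        arguments=json.dumps(args),
    )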

View file

@@ -12,6 +12,7 @@
 # the top-level of this source tree.

 import io
+import json
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
@@ -203,9 +204,10 @@ class ChatFormat:
             # This code tries to handle that case
             if tool_name in BuiltinTool.__members__:
                 tool_name = BuiltinTool[tool_name]
-                tool_arguments = {
-                    "query": list(tool_arguments.values())[0],
-                }
+                if isinstance(tool_arguments, dict):
+                    tool_arguments = {
+                        "query": list(tool_arguments.values())[0],
+                    }
         else:
             builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
             if builtin_tool_info is not None:
@@ -229,6 +231,7 @@ class ChatFormat:
                     call_id=call_id,
                     tool_name=tool_name,
                     arguments=tool_arguments,
+                    arguments_json=json.dumps(tool_arguments),
                 )
             )
         content = ""

View file

@ -11,11 +11,8 @@
# top-level folder for each specific model found within the models/ directory at # top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree. # the top-level of this source tree.
from llama_stack.models.llama.datatypes import (
BuiltinTool, from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
StopReason,
ToolCall,
)
from .prompt_templates import ( from .prompt_templates import (
BuiltinToolGenerator, BuiltinToolGenerator,

View file

@@ -582,6 +582,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                         tool_name=t.function.name,
                         # vLLM function args come back as a string. Llama Stack expects JSON.
                         arguments=json.loads(t.function.arguments),
+                        arguments_json=t.function.arguments,
                     )
                     for t in vllm_message.tool_calls
                 ],

View file

@@ -42,9 +42,7 @@ from llama_stack.models.llama.datatypes import (
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_stream_response,
 )
@@ -293,14 +291,12 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         if not tool_calls:
             return []

-        for call in tool_calls:
-            call_function_arguments = json.loads(call.function.arguments)
-
         compitable_tool_calls = [
             ToolCall(
                 call_id=call.id,
                 tool_name=call.function.name,
-                arguments=call_function_arguments,
+                arguments=json.loads(call.function.arguments),
+                arguments_json=call.function.arguments,
             )
             for call in tool_calls
         ]
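
Beyond adding `arguments_json`, this hunk quietly fixes a bug: the old standalone loop rebound `call_function_arguments` on every iteration, so each emitted ToolCall got the arguments of the last call. Parsing inside the comprehension keeps them per-call. A small illustration of the pitfall:

    import json

    calls = [{"arguments": '{"n": 1}'}, {"arguments": '{"n": 2}'}]

    # Old pattern: the name is rebound each iteration; only the last value survives.
    for call in calls:
        call_function_arguments = json.loads(call["arguments"])
    print([call_function_arguments for _ in calls])     # [{'n': 2}, {'n': 2}]

    # New pattern: parse per element, inside the comprehension.
    print([json.loads(c["arguments"]) for c in calls])  # [{'n': 1}, {'n': 2}]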

View file

@@ -97,7 +97,8 @@ def _convert_to_vllm_tool_calls_in_response(
         ToolCall(
             call_id=call.id,
             tool_name=call.function.name,
-            arguments=call_function_arguments,
+            arguments=json.loads(call.function.arguments),
+            arguments_json=call_function_arguments,
         )
         for call in tool_calls
     ]
@@ -181,7 +182,7 @@ async def _process_vllm_chat_completion_stream_response(
                             tool_call=ToolCall(
                                 call_id=tool_call_buf.call_id,
                                 tool_name=tool_call_buf.tool_name,
-                                arguments=args,
+                                arguments=args_str,
                             ),
                             parse_status=ToolCallParseStatus.succeeded,
                         ),

View file

@@ -529,7 +529,11 @@ async def convert_message_to_openai_dict_new(
 ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
     async def impl(
         content_: InterleavedContent,
-    ) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
+    ) -> Union[
+        str,
+        OpenAIChatCompletionContentPartParam,
+        List[OpenAIChatCompletionContentPartParam],
+    ]:
         # Llama Stack and OpenAI spec match for str and text input
         if isinstance(content_, str):
             return content_
@@ -570,7 +574,7 @@ async def convert_message_to_openai_dict_new(
             OpenAIChatCompletionMessageToolCall(
                 id=tool.call_id,
                 function=OpenAIFunction(
-                    name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
+                    name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
                     arguments=json.dumps(tool.arguments),
                 ),
                 type="function",
@@ -609,6 +613,7 @@ def convert_tool_call(
             call_id=tool_call.id,
             tool_name=tool_call.function.name,
             arguments=json.loads(tool_call.function.arguments),
+            arguments_json=tool_call.function.arguments,
         )
     except Exception:
         return UnparseableToolCall(
@@ -759,6 +764,7 @@ def _convert_openai_tool_calls(
             call_id=call.id,
             tool_name=call.function.name,
             arguments=json.loads(call.function.arguments),
+            arguments_json=call.function.arguments,
         )
         for call in tool_calls
     ]
@@ -890,7 +896,8 @@ async def convert_openai_chat_completion_stream(
             # ChatCompletionResponseEvent only supports one per stream
             if len(choice.delta.tool_calls) > 1:
                 warnings.warn(
-                    "multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
+                    "multiple tool calls found in a single delta, using the first, ignoring the rest",
+                    stacklevel=2,
                 )

             if not enable_incremental_tool_calls:
@@ -971,6 +978,7 @@ async def convert_openai_chat_completion_stream(
                         call_id=buffer["call_id"],
                         tool_name=buffer["name"],
                         arguments=arguments,
+                        arguments_json=buffer["arguments"],
                     )
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
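
The conversion helpers now carry the provider's raw argument string through unchanged alongside the parsed dict. Reduced to plain dicts for illustration (values hypothetical):

    import json

    # Shape of an OpenAI-style tool call, as received from the provider.
    openai_call = {
        "id": "call_abc123",
        "function": {"name": "get_weather", "arguments": '{"city": "Tokyo"}'},
    }

    converted = dict(
        call_id=openai_call["id"],
        tool_name=openai_call["function"]["name"],
        arguments=json.loads(openai_call["function"]["arguments"]),
        arguments_json=openai_call["function"]["arguments"],  # raw string, preserved
    )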

View file

@@ -165,7 +165,10 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         request.model = MODEL
         request.tool_config.tool_prompt_format = ToolPromptFormat.json
         prompt = await chat_completion_request_to_prompt(request, request.model)
-        self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
+        self.assertIn(
+            '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
+            prompt,
+        )

     async def test_user_provided_system_message(self):
         content = "Hello !"