new param arguments_json in ToolCall

This commit is contained in:
Hardik Shah 2025-03-18 13:44:29 -07:00
parent 37f155e41d
commit 549096b264
10 changed files with 132 additions and 107 deletions

View file

@ -47,7 +47,9 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
class ToolCall(BaseModel):
call_id: str
tool_name: Union[BuiltinTool, str]
arguments: Dict[str, RecursiveType]
arguments: Union[str, Dict[str, RecursiveType]]
# Temporary field for backwards compatibility
arguments_json: Optional[str] = None
@field_validator("tool_name", mode="before")
@classmethod

View file

@ -12,6 +12,7 @@
# the top-level of this source tree.
import io
import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@ -203,9 +204,10 @@ class ChatFormat:
# This code tries to handle that case
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
if isinstance(tool_arguments, dict):
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
else:
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
if builtin_tool_info is not None:
@ -229,6 +231,7 @@ class ChatFormat:
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
arguments_json=json.dumps(tool_arguments),
)
)
content = ""

View file

@ -11,11 +11,8 @@
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
)
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from .prompt_templates import (
BuiltinToolGenerator,