mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:42:25 +00:00
chore(api): add mypy coverage to tools_utils
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
51b179e1c5
commit
5c1671ac8c
2 changed files with 77 additions and 17 deletions
|
|
@ -167,7 +167,7 @@ def parse_llama_tool_call_format(input_string):
|
|||
class ToolUtils:
|
||||
@staticmethod
|
||||
def is_builtin_tool_call(message_body: str) -> bool:
|
||||
match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body)
|
||||
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
|
||||
return match is not None
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -184,7 +184,7 @@ class ToolUtils:
|
|||
return None
|
||||
|
||||
@staticmethod
|
||||
def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, str] | None:
|
||||
def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, dict] | None:
|
||||
# NOTE: Custom function too calls are still experimental
|
||||
# Sometimes, response is of the form
|
||||
# {"type": "function", "name": "function_name", "parameters": {...}
|
||||
|
|
@ -197,7 +197,11 @@ class ToolUtils:
|
|||
tool_name = match.group("function_name")
|
||||
query = match.group("args")
|
||||
try:
|
||||
return tool_name, json.loads(query.replace("'", '"'))
|
||||
parsed_args = json.loads(query.replace("'", '"'))
|
||||
if isinstance(parsed_args, dict):
|
||||
return tool_name, parsed_args
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print("Exception while parsing json query for custom tool call", query, e)
|
||||
return None
|
||||
|
|
@ -208,12 +212,22 @@ class ToolUtils:
|
|||
):
|
||||
function_name = response["name"]
|
||||
args = response["parameters"]
|
||||
return function_name, args
|
||||
if isinstance(args, dict):
|
||||
return function_name, args
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
elif function_calls := parse_llama_tool_call_format(message_body):
|
||||
# FIXME: Enable multiple tool calls
|
||||
return function_calls[0]
|
||||
if function_calls and len(function_calls) > 0:
|
||||
func_name, args_dict = function_calls[0]
|
||||
if isinstance(args_dict, dict):
|
||||
return func_name, args_dict
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
logger.debug(f"Did not parse tool call from message body: {message_body}")
|
||||
return None
|
||||
|
|
@ -221,29 +235,64 @@ class ToolUtils:
|
|||
@staticmethod
|
||||
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
|
||||
if t.tool_name == BuiltinTool.brave_search:
|
||||
q = t.arguments["query"]
|
||||
return f'brave_search.call(query="{q}")'
|
||||
if isinstance(t.arguments, dict):
|
||||
q = t.arguments["query"]
|
||||
return f'brave_search.call(query="{q}")'
|
||||
else:
|
||||
# Handle string arguments
|
||||
return f'brave_search.call(query="{t.arguments}")'
|
||||
elif t.tool_name == BuiltinTool.wolfram_alpha:
|
||||
q = t.arguments["query"]
|
||||
return f'wolfram_alpha.call(query="{q}")'
|
||||
if isinstance(t.arguments, dict):
|
||||
q = t.arguments["query"]
|
||||
return f'wolfram_alpha.call(query="{q}")'
|
||||
else:
|
||||
# Handle string arguments
|
||||
return f'wolfram_alpha.call(query="{t.arguments}")'
|
||||
elif t.tool_name == BuiltinTool.photogen:
|
||||
q = t.arguments["query"]
|
||||
return f'photogen.call(query="{q}")'
|
||||
if isinstance(t.arguments, dict):
|
||||
q = t.arguments["query"]
|
||||
return f'photogen.call(query="{q}")'
|
||||
else:
|
||||
# Handle string arguments
|
||||
return f'photogen.call(query="{t.arguments}")'
|
||||
elif t.tool_name == BuiltinTool.code_interpreter:
|
||||
return t.arguments["code"]
|
||||
if isinstance(t.arguments, dict):
|
||||
code = t.arguments["code"]
|
||||
return str(code)
|
||||
else:
|
||||
# Handle string arguments
|
||||
return str(t.arguments)
|
||||
else:
|
||||
fname = t.tool_name
|
||||
|
||||
if tool_prompt_format == ToolPromptFormat.json:
|
||||
# For JSON format, we need to handle both string and dict arguments
|
||||
if isinstance(t.arguments, str):
|
||||
# Try to parse string as JSON
|
||||
try:
|
||||
parsed_args = json.loads(t.arguments)
|
||||
except json.JSONDecodeError:
|
||||
parsed_args = {"value": t.arguments}
|
||||
else:
|
||||
parsed_args = t.arguments
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"type": "function",
|
||||
"name": fname,
|
||||
"parameters": t.arguments,
|
||||
"parameters": parsed_args,
|
||||
}
|
||||
)
|
||||
elif tool_prompt_format == ToolPromptFormat.function_tag:
|
||||
args = json.dumps(t.arguments)
|
||||
if isinstance(t.arguments, str):
|
||||
# Try to parse string as JSON
|
||||
try:
|
||||
parsed_args = json.loads(t.arguments)
|
||||
args = json.dumps(parsed_args)
|
||||
except json.JSONDecodeError:
|
||||
args = json.dumps({"value": t.arguments})
|
||||
else:
|
||||
args = json.dumps(t.arguments)
|
||||
return f"<function={fname}>{args}</function>"
|
||||
|
||||
elif tool_prompt_format == ToolPromptFormat.python_list:
|
||||
|
|
@ -256,11 +305,22 @@ class ToolUtils:
|
|||
elif isinstance(value, list):
|
||||
return f"[{', '.join(format_value(v) for v in value)}]"
|
||||
elif isinstance(value, dict):
|
||||
return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
|
||||
return f"{{{', '.join(f'"{k}"={format_value(v)}' for k, v in value.items())}}}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported type: {type(value)}")
|
||||
|
||||
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
|
||||
if isinstance(t.arguments, str):
|
||||
# Try to parse string as JSON
|
||||
try:
|
||||
parsed_args = json.loads(t.arguments)
|
||||
if isinstance(parsed_args, dict):
|
||||
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in parsed_args.items())
|
||||
else:
|
||||
args_str = f"value={format_value(parsed_args)}"
|
||||
except json.JSONDecodeError:
|
||||
args_str = f'value="{t.arguments}"'
|
||||
else:
|
||||
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
|
||||
return f"[{fname}({args_str})]"
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
|
||||
|
|
|
|||
|
|
@ -245,7 +245,7 @@ exclude = [
|
|||
"^llama_stack/models/llama/llama3/chat_format\\.py$",
|
||||
"^llama_stack/models/llama/llama3/interface\\.py$",
|
||||
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
|
||||
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
|
||||
"^llama_stack/models/llama/llama3_3/prompts\\.py$",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue