mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:42:25 +00:00
chore(api): add mypy coverage to tools_utils
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
51b179e1c5
commit
5c1671ac8c
2 changed files with 77 additions and 17 deletions
|
|
@ -167,7 +167,7 @@ def parse_llama_tool_call_format(input_string):
|
||||||
class ToolUtils:
|
class ToolUtils:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_builtin_tool_call(message_body: str) -> bool:
|
def is_builtin_tool_call(message_body: str) -> bool:
|
||||||
match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body)
|
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
|
||||||
return match is not None
|
return match is not None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -184,7 +184,7 @@ class ToolUtils:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, str] | None:
|
def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, dict] | None:
|
||||||
# NOTE: Custom function too calls are still experimental
|
# NOTE: Custom function too calls are still experimental
|
||||||
# Sometimes, response is of the form
|
# Sometimes, response is of the form
|
||||||
# {"type": "function", "name": "function_name", "parameters": {...}
|
# {"type": "function", "name": "function_name", "parameters": {...}
|
||||||
|
|
@ -197,7 +197,11 @@ class ToolUtils:
|
||||||
tool_name = match.group("function_name")
|
tool_name = match.group("function_name")
|
||||||
query = match.group("args")
|
query = match.group("args")
|
||||||
try:
|
try:
|
||||||
return tool_name, json.loads(query.replace("'", '"'))
|
parsed_args = json.loads(query.replace("'", '"'))
|
||||||
|
if isinstance(parsed_args, dict):
|
||||||
|
return tool_name, parsed_args
|
||||||
|
else:
|
||||||
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Exception while parsing json query for custom tool call", query, e)
|
print("Exception while parsing json query for custom tool call", query, e)
|
||||||
return None
|
return None
|
||||||
|
|
@ -208,12 +212,22 @@ class ToolUtils:
|
||||||
):
|
):
|
||||||
function_name = response["name"]
|
function_name = response["name"]
|
||||||
args = response["parameters"]
|
args = response["parameters"]
|
||||||
return function_name, args
|
if isinstance(args, dict):
|
||||||
|
return function_name, args
|
||||||
|
else:
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
elif function_calls := parse_llama_tool_call_format(message_body):
|
elif function_calls := parse_llama_tool_call_format(message_body):
|
||||||
# FIXME: Enable multiple tool calls
|
# FIXME: Enable multiple tool calls
|
||||||
return function_calls[0]
|
if function_calls and len(function_calls) > 0:
|
||||||
|
func_name, args_dict = function_calls[0]
|
||||||
|
if isinstance(args_dict, dict):
|
||||||
|
return func_name, args_dict
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Did not parse tool call from message body: {message_body}")
|
logger.debug(f"Did not parse tool call from message body: {message_body}")
|
||||||
return None
|
return None
|
||||||
|
|
@ -221,29 +235,64 @@ class ToolUtils:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
|
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
|
||||||
if t.tool_name == BuiltinTool.brave_search:
|
if t.tool_name == BuiltinTool.brave_search:
|
||||||
q = t.arguments["query"]
|
if isinstance(t.arguments, dict):
|
||||||
return f'brave_search.call(query="{q}")'
|
q = t.arguments["query"]
|
||||||
|
return f'brave_search.call(query="{q}")'
|
||||||
|
else:
|
||||||
|
# Handle string arguments
|
||||||
|
return f'brave_search.call(query="{t.arguments}")'
|
||||||
elif t.tool_name == BuiltinTool.wolfram_alpha:
|
elif t.tool_name == BuiltinTool.wolfram_alpha:
|
||||||
q = t.arguments["query"]
|
if isinstance(t.arguments, dict):
|
||||||
return f'wolfram_alpha.call(query="{q}")'
|
q = t.arguments["query"]
|
||||||
|
return f'wolfram_alpha.call(query="{q}")'
|
||||||
|
else:
|
||||||
|
# Handle string arguments
|
||||||
|
return f'wolfram_alpha.call(query="{t.arguments}")'
|
||||||
elif t.tool_name == BuiltinTool.photogen:
|
elif t.tool_name == BuiltinTool.photogen:
|
||||||
q = t.arguments["query"]
|
if isinstance(t.arguments, dict):
|
||||||
return f'photogen.call(query="{q}")'
|
q = t.arguments["query"]
|
||||||
|
return f'photogen.call(query="{q}")'
|
||||||
|
else:
|
||||||
|
# Handle string arguments
|
||||||
|
return f'photogen.call(query="{t.arguments}")'
|
||||||
elif t.tool_name == BuiltinTool.code_interpreter:
|
elif t.tool_name == BuiltinTool.code_interpreter:
|
||||||
return t.arguments["code"]
|
if isinstance(t.arguments, dict):
|
||||||
|
code = t.arguments["code"]
|
||||||
|
return str(code)
|
||||||
|
else:
|
||||||
|
# Handle string arguments
|
||||||
|
return str(t.arguments)
|
||||||
else:
|
else:
|
||||||
fname = t.tool_name
|
fname = t.tool_name
|
||||||
|
|
||||||
if tool_prompt_format == ToolPromptFormat.json:
|
if tool_prompt_format == ToolPromptFormat.json:
|
||||||
|
# For JSON format, we need to handle both string and dict arguments
|
||||||
|
if isinstance(t.arguments, str):
|
||||||
|
# Try to parse string as JSON
|
||||||
|
try:
|
||||||
|
parsed_args = json.loads(t.arguments)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
parsed_args = {"value": t.arguments}
|
||||||
|
else:
|
||||||
|
parsed_args = t.arguments
|
||||||
|
|
||||||
return json.dumps(
|
return json.dumps(
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"name": fname,
|
"name": fname,
|
||||||
"parameters": t.arguments,
|
"parameters": parsed_args,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif tool_prompt_format == ToolPromptFormat.function_tag:
|
elif tool_prompt_format == ToolPromptFormat.function_tag:
|
||||||
args = json.dumps(t.arguments)
|
if isinstance(t.arguments, str):
|
||||||
|
# Try to parse string as JSON
|
||||||
|
try:
|
||||||
|
parsed_args = json.loads(t.arguments)
|
||||||
|
args = json.dumps(parsed_args)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
args = json.dumps({"value": t.arguments})
|
||||||
|
else:
|
||||||
|
args = json.dumps(t.arguments)
|
||||||
return f"<function={fname}>{args}</function>"
|
return f"<function={fname}>{args}</function>"
|
||||||
|
|
||||||
elif tool_prompt_format == ToolPromptFormat.python_list:
|
elif tool_prompt_format == ToolPromptFormat.python_list:
|
||||||
|
|
@ -256,11 +305,22 @@ class ToolUtils:
|
||||||
elif isinstance(value, list):
|
elif isinstance(value, list):
|
||||||
return f"[{', '.join(format_value(v) for v in value)}]"
|
return f"[{', '.join(format_value(v) for v in value)}]"
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
|
return f"{{{', '.join(f'"{k}"={format_value(v)}' for k, v in value.items())}}}"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported type: {type(value)}")
|
raise ValueError(f"Unsupported type: {type(value)}")
|
||||||
|
|
||||||
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
|
if isinstance(t.arguments, str):
|
||||||
|
# Try to parse string as JSON
|
||||||
|
try:
|
||||||
|
parsed_args = json.loads(t.arguments)
|
||||||
|
if isinstance(parsed_args, dict):
|
||||||
|
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in parsed_args.items())
|
||||||
|
else:
|
||||||
|
args_str = f"value={format_value(parsed_args)}"
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
args_str = f'value="{t.arguments}"'
|
||||||
|
else:
|
||||||
|
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
|
||||||
return f"[{fname}({args_str})]"
|
return f"[{fname}({args_str})]"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
|
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
|
||||||
|
|
|
||||||
|
|
@ -245,7 +245,7 @@ exclude = [
|
||||||
"^llama_stack/models/llama/llama3/chat_format\\.py$",
|
"^llama_stack/models/llama/llama3/chat_format\\.py$",
|
||||||
"^llama_stack/models/llama/llama3/interface\\.py$",
|
"^llama_stack/models/llama/llama3/interface\\.py$",
|
||||||
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
|
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
|
||||||
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
|
"^llama_stack/models/llama/llama3_3/prompts\\.py$",
|
||||||
"^llama_stack/providers/inline/agents/meta_reference/",
|
"^llama_stack/providers/inline/agents/meta_reference/",
|
||||||
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
|
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
|
||||||
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue