mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-26 09:15:40 +00:00 
			
		
		
		
	# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
		
			
				
	
	
		
			266 lines
		
	
	
	
		
			9.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			266 lines
		
	
	
	
		
			9.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import json
 | |
| import re
 | |
| 
 | |
| from llama_stack.log import get_logger
 | |
| 
 | |
| from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
 | |
| 
 | |
| logger = get_logger(name=__name__, category="inference")
 | |
| 
 | |
| BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
 | |
| CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
 | |
| 
 | |
| 
 | |
| def is_json(s):
 | |
|     try:
 | |
|         parsed = json.loads(s)
 | |
|         # Return True for valid objects and not for ints, strings, etc
 | |
|         return isinstance(parsed, dict)
 | |
|     except json.JSONDecodeError:
 | |
|         return False
 | |
|     return True
 | |
| 
 | |
| 
 | |
| def parse_llama_tool_call_format(input_string):
 | |
|     """
 | |
|     Parse tool calls in the format:
 | |
|     [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
 | |
| 
 | |
|     Returns a list of (function_name, arguments_dict) tuples or None if parsing fails.
 | |
|     """
 | |
|     # Strip outer brackets and whitespace
 | |
|     input_string = input_string.strip()
 | |
|     if not (input_string.startswith("[") and input_string.endswith("]")):
 | |
|         return None
 | |
| 
 | |
|     content = input_string[1:-1].strip()
 | |
|     if not content:
 | |
|         return None
 | |
| 
 | |
|     result = []
 | |
| 
 | |
|     # State variables for parsing
 | |
|     pos = 0
 | |
|     length = len(content)
 | |
| 
 | |
|     while pos < length:
 | |
|         # Find function name
 | |
|         name_end = content.find("(", pos)
 | |
|         if name_end == -1:
 | |
|             break
 | |
| 
 | |
|         func_name = content[pos:name_end].strip()
 | |
| 
 | |
|         # Find closing parenthesis for this function call
 | |
|         paren_level = 1
 | |
|         args_start = name_end + 1
 | |
|         args_end = args_start
 | |
| 
 | |
|         while args_end < length and paren_level > 0:
 | |
|             if content[args_end] == "(":
 | |
|                 paren_level += 1
 | |
|             elif content[args_end] == ")":
 | |
|                 paren_level -= 1
 | |
|             args_end += 1
 | |
| 
 | |
|         if paren_level != 0:
 | |
|             # Unmatched parentheses
 | |
|             return None
 | |
| 
 | |
|         # Parse arguments
 | |
|         args_str = content[args_start : args_end - 1].strip()
 | |
|         args_dict = {}
 | |
| 
 | |
|         if args_str:
 | |
|             # Split by commas, but respect nested structures
 | |
|             parts = []
 | |
|             part_start = 0
 | |
|             in_quotes = False
 | |
|             quote_char = None
 | |
|             nested_level = 0
 | |
| 
 | |
|             for i, char in enumerate(args_str):
 | |
|                 if char in ('"', "'") and (i == 0 or args_str[i - 1] != "\\"):
 | |
|                     if not in_quotes:
 | |
|                         in_quotes = True
 | |
|                         quote_char = char
 | |
|                     elif char == quote_char:
 | |
|                         in_quotes = False
 | |
|                         quote_char = None
 | |
|                 elif not in_quotes:
 | |
|                     if char in ("{", "["):
 | |
|                         nested_level += 1
 | |
|                     elif char in ("}", "]"):
 | |
|                         nested_level -= 1
 | |
|                     elif char == "," and nested_level == 0:
 | |
|                         parts.append(args_str[part_start:i].strip())
 | |
|                         part_start = i + 1
 | |
| 
 | |
|             parts.append(args_str[part_start:].strip())
 | |
| 
 | |
|             # Process each key=value pair
 | |
|             for part in parts:
 | |
|                 if "=" in part:
 | |
|                     key, value = part.split("=", 1)
 | |
|                     key = key.strip()
 | |
|                     value = value.strip()
 | |
| 
 | |
|                     # Try to convert value to appropriate Python type
 | |
|                     if (value.startswith('"') and value.endswith('"')) or (
 | |
|                         value.startswith("'") and value.endswith("'")
 | |
|                     ):
 | |
|                         # String
 | |
|                         value = value[1:-1]
 | |
|                     elif value.lower() == "true":
 | |
|                         value = True
 | |
|                     elif value.lower() == "false":
 | |
|                         value = False
 | |
|                     elif value.lower() == "none":
 | |
|                         value = None
 | |
|                     elif value.startswith("{") and value.endswith("}"):
 | |
|                         # This is a nested dictionary
 | |
|                         try:
 | |
|                             # Try to parse as JSON
 | |
|                             value = json.loads(value.replace("'", '"'))
 | |
|                         except json.JSONDecodeError:
 | |
|                             # Keep as string if parsing fails
 | |
|                             pass
 | |
|                     elif value.startswith("[") and value.endswith("]"):
 | |
|                         # This is a nested list
 | |
|                         try:
 | |
|                             # Try to parse as JSON
 | |
|                             value = json.loads(value.replace("'", '"'))
 | |
|                         except json.JSONDecodeError:
 | |
|                             # Keep as string if parsing fails
 | |
|                             pass
 | |
|                     else:
 | |
|                         # Try to convert to number
 | |
|                         try:
 | |
|                             if "." in value:
 | |
|                                 value = float(value)
 | |
|                             else:
 | |
|                                 value = int(value)
 | |
|                         except ValueError:
 | |
|                             # Keep as string if not a valid number
 | |
|                             pass
 | |
| 
 | |
|                     args_dict[key] = value
 | |
| 
 | |
|         result.append((func_name, args_dict))
 | |
| 
 | |
|         # Move to the next function call
 | |
|         pos = args_end
 | |
| 
 | |
|         # Skip the comma between function calls if present
 | |
|         if pos < length and content[pos] == ",":
 | |
|             pos += 1
 | |
| 
 | |
|     return result if result else None
 | |
| 
 | |
| 
 | |
| class ToolUtils:
 | |
|     @staticmethod
 | |
|     def is_builtin_tool_call(message_body: str) -> bool:
 | |
|         match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body)
 | |
|         return match is not None
 | |
| 
 | |
|     @staticmethod
 | |
|     def maybe_extract_builtin_tool_call(message_body: str) -> tuple[str, str] | None:
 | |
|         # Find the first match in the text
 | |
|         match = re.search(BUILTIN_TOOL_PATTERN, message_body)
 | |
| 
 | |
|         # Check if a match is found and return it
 | |
|         if match:
 | |
|             tool_name = match.group("tool_name")
 | |
|             query = match.group("query")
 | |
|             return tool_name, query
 | |
|         else:
 | |
|             return None
 | |
| 
 | |
|     @staticmethod
 | |
|     def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, str] | None:
 | |
|         # NOTE: Custom function too calls are still experimental
 | |
|         # Sometimes, response is of the form
 | |
|         # {"type": "function", "name": "function_name", "parameters": {...}
 | |
|         # and some times
 | |
|         # <function=function_name>(parameters)</function>
 | |
| 
 | |
|         # Find the first match in the text
 | |
|         match = re.search(CUSTOM_TOOL_CALL_PATTERN, message_body)
 | |
|         if match:
 | |
|             tool_name = match.group("function_name")
 | |
|             query = match.group("args")
 | |
|             try:
 | |
|                 return tool_name, json.loads(query.replace("'", '"'))
 | |
|             except Exception as e:
 | |
|                 print("Exception while parsing json query for custom tool call", query, e)
 | |
|                 return None
 | |
|         elif is_json(message_body):
 | |
|             response = json.loads(message_body)
 | |
|             if ("type" in response and response["type"] == "function") or (
 | |
|                 "name" in response and "parameters" in response
 | |
|             ):
 | |
|                 function_name = response["name"]
 | |
|                 args = response["parameters"]
 | |
|                 return function_name, args
 | |
|             else:
 | |
|                 return None
 | |
|         elif function_calls := parse_llama_tool_call_format(message_body):
 | |
|             # FIXME: Enable multiple tool calls
 | |
|             return function_calls[0]
 | |
|         else:
 | |
|             logger.debug(f"Did not parse tool call from message body: {message_body}")
 | |
|             return None
 | |
| 
 | |
|     @staticmethod
 | |
|     def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
 | |
|         if t.tool_name == BuiltinTool.brave_search:
 | |
|             q = t.arguments["query"]
 | |
|             return f'brave_search.call(query="{q}")'
 | |
|         elif t.tool_name == BuiltinTool.wolfram_alpha:
 | |
|             q = t.arguments["query"]
 | |
|             return f'wolfram_alpha.call(query="{q}")'
 | |
|         elif t.tool_name == BuiltinTool.photogen:
 | |
|             q = t.arguments["query"]
 | |
|             return f'photogen.call(query="{q}")'
 | |
|         elif t.tool_name == BuiltinTool.code_interpreter:
 | |
|             return t.arguments["code"]
 | |
|         else:
 | |
|             fname = t.tool_name
 | |
| 
 | |
|             if tool_prompt_format == ToolPromptFormat.json:
 | |
|                 return json.dumps(
 | |
|                     {
 | |
|                         "type": "function",
 | |
|                         "name": fname,
 | |
|                         "parameters": t.arguments,
 | |
|                     }
 | |
|                 )
 | |
|             elif tool_prompt_format == ToolPromptFormat.function_tag:
 | |
|                 args = json.dumps(t.arguments)
 | |
|                 return f"<function={fname}>{args}</function>"
 | |
| 
 | |
|             elif tool_prompt_format == ToolPromptFormat.python_list:
 | |
| 
 | |
|                 def format_value(value: RecursiveType) -> str:
 | |
|                     if isinstance(value, str):
 | |
|                         return f'"{value}"'
 | |
|                     elif isinstance(value, int | float | bool) or value is None:
 | |
|                         return str(value)
 | |
|                     elif isinstance(value, list):
 | |
|                         return f"[{', '.join(format_value(v) for v in value)}]"
 | |
|                     elif isinstance(value, dict):
 | |
|                         return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
 | |
|                     else:
 | |
|                         raise ValueError(f"Unsupported type: {type(value)}")
 | |
| 
 | |
|                 args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
 | |
|                 return f"[{fname}({args_str})]"
 | |
|             else:
 | |
|                 raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
 |