diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index e322adf35..2d66dc60b 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -358,17 +358,13 @@ def augment_messages_for_tools_llama_3_1( has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools) if has_custom_tools: - if ( - request.tool_prompt_format == ToolPromptFormat.json - or request.tool_prompt_format is None - ): + fmt = request.tool_prompt_format or ToolPromptFormat.json + if fmt == ToolPromptFormat.json: tool_gen = JsonCustomToolGenerator() - elif request.tool_prompt_format == ToolPromptFormat.function_tag: + elif fmt == ToolPromptFormat.function_tag: tool_gen = FunctionTagCustomToolGenerator() else: - raise ValueError( - f"Non supported ToolPromptFormat {request.tool_prompt_format}" - ) + raise ValueError(f"Non supported ToolPromptFormat {fmt}") custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)] custom_template = tool_gen.gen(custom_tools) @@ -413,10 +409,8 @@ def augment_messages_for_tools_llama_3_2( custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)] if custom_tools: - if ( - request.tool_prompt_format is not None - and request.tool_prompt_format != ToolPromptFormat.python_list - ): + fmt = request.tool_prompt_format or ToolPromptFormat.python_list + if fmt != ToolPromptFormat.python_list: raise ValueError( f"Non supported ToolPromptFormat {request.tool_prompt_format}" )