simplify fmt check

This commit is contained in:
Dinesh Yeduguru 2025-01-10 10:40:03 -08:00
parent 5f69747b89
commit df2d86b0da

View file

@ -358,17 +358,13 @@ def augment_messages_for_tools_llama_3_1(
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
if (
request.tool_prompt_format == ToolPromptFormat.json
or request.tool_prompt_format is None
):
fmt = request.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif request.tool_prompt_format == ToolPromptFormat.function_tag:
elif fmt == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)
raise ValueError(f"Non supported ToolPromptFormat {fmt}")
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
@ -413,10 +409,8 @@ def augment_messages_for_tools_llama_3_2(
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools:
if (
request.tool_prompt_format is not None
and request.tool_prompt_format != ToolPromptFormat.python_list
):
fmt = request.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)