simplify fmt check

This commit is contained in:
Dinesh Yeduguru 2025-01-10 10:40:03 -08:00
parent 5f69747b89
commit df2d86b0da

View file

@ -358,17 +358,13 @@ def augment_messages_for_tools_llama_3_1(
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools) has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools: if has_custom_tools:
if ( fmt = request.tool_prompt_format or ToolPromptFormat.json
request.tool_prompt_format == ToolPromptFormat.json if fmt == ToolPromptFormat.json:
or request.tool_prompt_format is None
):
tool_gen = JsonCustomToolGenerator() tool_gen = JsonCustomToolGenerator()
elif request.tool_prompt_format == ToolPromptFormat.function_tag: elif fmt == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator() tool_gen = FunctionTagCustomToolGenerator()
else: else:
raise ValueError( raise ValueError(f"Non supported ToolPromptFormat {fmt}")
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)] custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools) custom_template = tool_gen.gen(custom_tools)
@ -413,10 +409,8 @@ def augment_messages_for_tools_llama_3_2(
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)] custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools: if custom_tools:
if ( fmt = request.tool_prompt_format or ToolPromptFormat.python_list
request.tool_prompt_format is not None if fmt != ToolPromptFormat.python_list:
and request.tool_prompt_format != ToolPromptFormat.python_list
):
raise ValueError( raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}" f"Non supported ToolPromptFormat {request.tool_prompt_format}"
) )