support json format

This commit is contained in:
Hardik Shah 2024-08-14 12:43:43 -07:00
parent 48b78430eb
commit 86df597a83
7 changed files with 97 additions and 29 deletions

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import json
import textwrap
from datetime import datetime
from typing import List
@ -15,6 +16,7 @@ from llama_toolchain.inference.api import (
Message,
SystemMessage,
ToolDefinition,
UserMessage,
)
from .tools.builtin import SingleMessageBuiltinTool
@ -49,18 +51,43 @@ Today Date: {formatted_date}\n"""
if custom_tools:
if tool_prompt_format == ToolPromptFormat.function_tag:
custom_message = get_system_prompt_for_custom_tools(custom_tools)
custom_message = prompt_for_function_tag(custom_tools)
content += custom_message
messages.append(SystemMessage(content=content))
elif tool_prompt_format == ToolPromptFormat.json:
messages.append(SystemMessage(content=content))
# json is added as a user prompt
text = prompt_for_json(custom_tools)
messages.append(UserMessage(content=text))
else:
raise NotImplementedError(
f"Tool prompt format {tool_prompt_format} is not supported"
)
else:
messages.append(SystemMessage(content=content))
return messages
def get_system_prompt_for_custom_tools(custom_tools: List[ToolDefinition]) -> str:
def prompt_for_json(custom_tools: List[ToolDefinition]) -> str:
tool_defs = "\n".join(
translate_custom_tool_definition_to_json(t) for t in custom_tools
)
content = textwrap.dedent(
"""
Answer the user's question by making use of the following functions if needed.
If none of the function can be used, please say so.
Here is a list of functions in JSON format:
{tool_defs}
Return function calls in json format.
"""
)
content = content.lstrip("\n").format(tool_defs=tool_defs)
return content
def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str:
custom_tool_params = ""
for t in custom_tools:
custom_tool_params += get_instruction_string(t) + "\n"
@ -102,7 +129,6 @@ def get_parameters_string(custom_tool_definition) -> str:
)
# NOTE: Unused right now
def translate_custom_tool_definition_to_json(tool_def):
"""Translates ToolDefinition to json as expected by model
eg. output for a function
@ -153,4 +179,4 @@ def translate_custom_tool_definition_to_json(tool_def):
else:
func_def["function"]["parameters"] = {}
return json.dumps(func_def)
return json.dumps(func_def, indent=4)