Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-04 12:07:34 +00:00
support json format

parent 48b78430eb
commit 86df597a83

7 changed files with 97 additions and 29 deletions
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import json
 import textwrap
 from datetime import datetime
 from typing import List

@@ -15,6 +16,7 @@ from llama_toolchain.inference.api import (
     Message,
     SystemMessage,
     ToolDefinition,
     UserMessage,
 )

 from .tools.builtin import SingleMessageBuiltinTool
@@ -49,18 +51,43 @@ Today Date: {formatted_date}\n"""

     if custom_tools:
         if tool_prompt_format == ToolPromptFormat.function_tag:
-            custom_message = get_system_prompt_for_custom_tools(custom_tools)
+            custom_message = prompt_for_function_tag(custom_tools)
             content += custom_message
             messages.append(SystemMessage(content=content))
+        elif tool_prompt_format == ToolPromptFormat.json:
+            messages.append(SystemMessage(content=content))
+            # json is added as a user prompt
+            text = prompt_for_json(custom_tools)
+            messages.append(UserMessage(content=text))
         else:
             raise NotImplementedError(
                 f"Tool prompt format {tool_prompt_format} is not supported"
             )
     else:
         messages.append(SystemMessage(content=content))

     return messages


-def get_system_prompt_for_custom_tools(custom_tools: List[ToolDefinition]) -> str:
+def prompt_for_json(custom_tools: List[ToolDefinition]) -> str:
+    tool_defs = "\n".join(
+        translate_custom_tool_definition_to_json(t) for t in custom_tools
+    )
+    content = textwrap.dedent(
+        """
+        Answer the user's question by making use of the following functions if needed.
+        If none of the function can be used, please say so.
+        Here is a list of functions in JSON format:
+        {tool_defs}
+
+        Return function calls in json format.
+        """
+    )
+    content = content.lstrip("\n").format(tool_defs=tool_defs)
+    return content
+
+
+def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str:
     custom_tool_params = ""
     for t in custom_tools:
         custom_tool_params += get_instruction_string(t) + "\n"
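The substantive change in this hunk: with ToolPromptFormat.json, the tool definitions are no longer folded into the system prompt; the system message is sent as-is and the JSON tool list arrives as a separate user message via prompt_for_json. Below is a minimal, self-contained sketch of that branching. It uses stand-in dataclasses and a hypothetical build_messages helper rather than the real llama_toolchain.inference.api types, so everything in it is illustrative only.

# Sketch of the two message layouts; stand-in types, not llama_toolchain's.
from dataclasses import dataclass
from enum import Enum
from typing import List, Union


class ToolPromptFormat(Enum):
    function_tag = "function_tag"
    json = "json"


@dataclass
class SystemMessage:
    content: str


@dataclass
class UserMessage:
    content: str


Message = Union[SystemMessage, UserMessage]


def build_messages(content: str, tool_text: str, fmt: ToolPromptFormat) -> List[Message]:
    messages: List[Message] = []
    if fmt == ToolPromptFormat.function_tag:
        # function_tag: tool definitions ride inside the system prompt
        messages.append(SystemMessage(content=content + tool_text))
    elif fmt == ToolPromptFormat.json:
        # json: system prompt stays unchanged; tool JSON becomes a user prompt
        messages.append(SystemMessage(content=content))
        messages.append(UserMessage(content=tool_text))
    else:
        raise NotImplementedError(f"Tool prompt format {fmt} is not supported")
    return messages


if __name__ == "__main__":
    print(build_messages("You are a helpful assistant.",
                         '{"name": "get_weather"}',
                         ToolPromptFormat.json))

Running the sketch with ToolPromptFormat.json yields [SystemMessage(...), UserMessage(...)], whereas function_tag yields a single SystemMessage, mirroring the two branches in the diff.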
@@ -102,7 +129,6 @@ def get_parameters_string(custom_tool_definition) -> str:
     )


-# NOTE: Unused right now
 def translate_custom_tool_definition_to_json(tool_def):
     """Translates ToolDefinition to json as expected by model
     eg. output for a function
@@ -153,4 +179,4 @@ def translate_custom_tool_definition_to_json(tool_def):
     else:
         func_def["function"]["parameters"] = {}

-    return json.dumps(func_def)
+    return json.dumps(func_def, indent=4)
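The final hunk only pretty-prints the translated tool definition before it is interpolated into the prompt. A small sketch of the effect of indent=4 follows; the func_def dict is hypothetical apart from the "function" -> "parameters" nesting visible in the diff.

import json

# Hypothetical translated tool definition; only the "function"/"parameters"
# nesting is confirmed by the diff above.
func_def = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {},
    },
}

print(json.dumps(func_def))            # before: compact single-line JSON
print(json.dumps(func_def, indent=4))  # after: indented JSON, easier to read in the prompt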