Updates to prompt for tool calls (#29)

* Update system prompts to drop the trailing newline

* Add tool prompt formats

* Support JSON format

* JSON in caps

* The function_tag system prompt is also added as a user message

* Added docstrings for ToolPromptFormat

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Authored by Hardik Shah on 2024-08-15 13:23:51 -07:00; committed by GitHub
parent 0d933ac4c5
commit b8fc4d4dee
8 changed files with 173 additions and 30 deletions

@@ -110,6 +110,35 @@ class Session(BaseModel):
started_at: datetime
@json_schema_type
class ToolPromptFormat(Enum):
"""This Enum refers to the prompt format for calling zero-shot tools.
`json` --
Refers to the JSON format for calling tools.
The JSON format takes the form:
{
"type": "function",
"function": {
"name": "function_name",
"description": "function_description",
"parameters": {...}
}
}
`function_tag` --
This is an example of how you could define
your own user-defined format for making tool calls.
The function_tag format looks like this:
<function=function_name>(parameters)</function>
The detailed prompts for each of these formats are defined in `system_prompt.py`.
"""
json = "json"
function_tag = "function_tag"
@json_schema_type
class AgenticSystemInstanceConfig(BaseModel):
instructions: str
@@ -127,6 +156,9 @@ class AgenticSystemInstanceConfig(BaseModel):
# if you want to completely replace the messages prefixed by the system,
# this is debug only
debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
class AgenticSystemTurnResponseEventType(Enum):
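
For illustration, here is roughly what a call to the get_boiling_point tool defined later in this diff could look like under each format. This is a sketch based on the docstring above, not captured model output, and the exact argument encoding is an assumption:

# ToolPromptFormat.json -- the model emits a JSON object describing the call:
{
    "type": "function",
    "function": {
        "name": "get_boiling_point",
        "parameters": {"liquid_name": "polyjuice", "celcius": "true"}
    }
}

# ToolPromptFormat.function_tag -- the model wraps the call in a tag:
<function=get_boiling_point>(liquid_name="polyjuice", celcius="true")</function>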

@@ -13,8 +13,15 @@ import fire
import httpx
from llama_models.llama3_1.api.datatypes import BuiltinTool, SamplingParams
from llama_models.llama3_1.api.datatypes import (
BuiltinTool,
SamplingParams,
ToolParamDefinition,
UserMessage,
)
from termcolor import cprint
from llama_toolchain.agentic_system.event_logger import EventLogger
from .api import (
AgenticSystem,
AgenticSystemCreateRequest,
@@ -25,6 +32,7 @@ from .api import (
AgenticSystemToolDefinition,
AgenticSystemTurnCreateRequest,
AgenticSystemTurnResponseStreamChunk,
ToolPromptFormat,
)
@@ -87,7 +95,7 @@ class AgenticSystemClient(AgenticSystem):
async def run_main(host: str, port: int):
# client to test remote impl of agentic system
api = await AgenticSystemClient(f"http://{host}:{port}")
api = AgenticSystemClient(f"http://{host}:{port}")
tool_definitions = [
AgenticSystemToolDefinition(
@@ -96,13 +104,28 @@ async def run_main(host: str, port: int):
AgenticSystemToolDefinition(
tool_name=BuiltinTool.wolfram_alpha,
),
AgenticSystemToolDefinition(
tool_name=BuiltinTool.photogen,
),
AgenticSystemToolDefinition(
tool_name=BuiltinTool.code_interpreter,
),
]
tool_definitions += [
AgenticSystemToolDefinition(
tool_name="get_boiling_point",
description="Get the boiling point of a imaginary liquids (eg. polyjuice)",
parameters={
"liquid_name": ToolParamDefinition(
param_type="str",
description="The name of the liquid",
required=True,
),
"celcius": ToolParamDefinition(
param_type="str",
description="Whether to return the boiling point in Celcius",
required=False,
),
},
),
]
create_request = AgenticSystemCreateRequest(
model="Meta-Llama3.1-8B-Instruct",
@@ -114,12 +137,44 @@ async def run_main(host: str, port: int):
output_shields=[],
quantization_config=None,
debug_prefix_messages=[],
tool_prompt_format=ToolPromptFormat.json,
),
)
create_response = await api.create_agentic_system(create_request)
print(create_response)
# TODO: Add chat session / turn apis to test e2e
session_response = await api.create_agentic_system_session(
AgenticSystemSessionCreateRequest(
system_id=create_response.system_id,
session_name="test_session",
)
)
print(session_response)
user_prompts = [
"Who are you?",
"What is the 100th prime number?",
"Search the web for who was the 44th President of the USA?",
"Write code to check if a number is prime. Use that to check if 7 is prime",
"What is the boiling point of polyjuicepotion?",
]
for content in user_prompts:
cprint(f"User> {content}", color="blue")
iterator = api.create_agentic_system_turn(
AgenticSystemTurnCreateRequest(
system_id=create_response.system_id,
session_id=session_response.session_id,
messages=[
UserMessage(content=content),
],
stream=True,
)
)
async for event, log in EventLogger().log(iterator):
if log is not None:
log.print()
def main(host: str, port: int):

@@ -10,6 +10,8 @@ import uuid
from datetime import datetime
from typing import AsyncGenerator, List, Optional
from termcolor import cprint
from llama_toolchain.agentic_system.api.datatypes import (
AgenticSystemInstanceConfig,
AgenticSystemTurnResponseEvent,
@@ -24,6 +26,7 @@ from llama_toolchain.agentic_system.api.datatypes import (
ShieldCallStep,
StepType,
ToolExecutionStep,
ToolPromptFormat,
Turn,
)
@@ -51,7 +54,6 @@ from llama_toolchain.safety.api.datatypes import (
ShieldDefinition,
ShieldResponse,
)
from termcolor import cprint
from llama_toolchain.agentic_system.api.endpoints import * # noqa
from .safety import SafetyException, ShieldRunnerMixin
@@ -74,6 +76,7 @@ class AgentInstance(ShieldRunnerMixin):
output_shields: List[ShieldDefinition],
max_infer_iters: int = 10,
prefix_messages: Optional[List[Message]] = None,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
):
self.system_id = system_id
self.instance_config = instance_config
@@ -86,7 +89,9 @@
self.prefix_messages = prefix_messages
else:
self.prefix_messages = get_agentic_prefix_messages(
builtin_tools, custom_tool_definitions
builtin_tools,
custom_tool_definitions,
tool_prompt_format,
)
for m in self.prefix_messages:

@@ -108,6 +108,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
input_shields=cfg.input_shields,
output_shields=cfg.output_shields,
prefix_messages=cfg.debug_prefix_messages,
tool_prompt_format=cfg.tool_prompt_format,
)
return AgenticSystemCreateResponse(

@@ -6,14 +6,15 @@
from typing import List
from llama_models.llama3_1.api.datatypes import Message, Role
from llama_models.llama3_1.api.datatypes import Message, Role, UserMessage
from termcolor import cprint
from llama_toolchain.safety.api.datatypes import (
OnViolationAction,
ShieldDefinition,
ShieldResponse,
)
from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety
from termcolor import cprint
class SafetyException(Exception): # noqa: N818
@@ -36,12 +37,11 @@ class ShieldRunnerMixin:
async def run_shields(
self, messages: List[Message], shields: List[ShieldDefinition]
) -> List[ShieldResponse]:
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
# TODO(ashwin): we need to change the type of the message, this kind of modification
# is no longer appropriate
messages[0].role = Role.user.value
messages[0] = UserMessage(content=messages[0].content)
res = await self.safety_api.run_shields(
RunShieldRequest(
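
A quick sketch of what the rewritten guard does: the list is copied so the caller's messages are not mutated, and a non-user first message is rebuilt as a UserMessage rather than having its role reassigned, since shields like llama-guard expect the conversation to open with a user turn. The variable names below are hypothetical, for illustration only:

msgs = conversation_history.copy()  # hypothetical list of Message objects
if len(msgs) > 0 and msgs[0].role != Role.user.value:
    # rebuild instead of mutating .role: typed messages pin their role
    msgs[0] = UserMessage(content=msgs[0].content)
# msgs can now be passed to the shields; conversation_history is unchanged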

@@ -5,21 +5,27 @@
# the root directory of this source tree.
import json
import textwrap
from datetime import datetime
from typing import List
from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
from llama_toolchain.inference.api import (
BuiltinTool,
Message,
SystemMessage,
ToolDefinition,
UserMessage,
)
from .tools.builtin import SingleMessageBuiltinTool
def get_agentic_prefix_messages(
builtin_tools: List[SingleMessageBuiltinTool], custom_tools: List[ToolDefinition]
builtin_tools: List[SingleMessageBuiltinTool],
custom_tools: List[ToolDefinition],
tool_prompt_format: ToolPromptFormat,
) -> List[Message]:
messages = []
content = ""
@@ -34,28 +40,52 @@ def get_agentic_prefix_messages(
]
)
if tool_str:
content += f"Tools: {tool_str}\n"
content += f"Tools: {tool_str}"
current_date = datetime.now()
formatted_date = current_date.strftime("%d %B %Y")
date_str = f"""
Cutting Knowledge Date: December 2023
Today Date: {formatted_date}\n\n"""
Today Date: {formatted_date}\n"""
content += date_str
messages.append(SystemMessage(content=content))
if custom_tools:
custom_message = get_system_prompt_for_custom_tools(custom_tools)
content += custom_message
# TODO: Replace this hard coded message with instructions coming in the request
if False:
content += "You are a helpful Assistant."
if tool_prompt_format == ToolPromptFormat.function_tag:
text = prompt_for_function_tag(custom_tools)
messages.append(UserMessage(content=text))
elif tool_prompt_format == ToolPromptFormat.json:
text = prompt_for_json(custom_tools)
messages.append(UserMessage(content=text))
else:
raise NotImplementedError(
f"Tool prompt format {tool_prompt_format} is not supported"
)
else:
messages.append(SystemMessage(content=content))
return messages
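
Put together, for an agent configured with ToolPromptFormat.json and at least one custom tool, the prefix now comes out as two messages rather than one. A sketch of the resulting shape, with contents abbreviated:

messages = [
    SystemMessage(content="Tools: ...\nCutting Knowledge Date: December 2023\nToday Date: ..."),
    UserMessage(content=prompt_for_json(custom_tools)),
]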
def get_system_prompt_for_custom_tools(custom_tools: List[ToolDefinition]) -> str:
def prompt_for_json(custom_tools: List[ToolDefinition]) -> str:
tool_defs = "\n".join(
translate_custom_tool_definition_to_json(t) for t in custom_tools
)
content = textwrap.dedent(
"""
Answer the user's question by making use of the following functions if needed.
If none of the functions can be used, please say so.
Here is a list of functions in JSON format:
{tool_defs}
Return function calls in JSON format.
"""
)
content = content.lstrip("\n").format(tool_defs=tool_defs)
return content
def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str:
custom_tool_params = ""
for t in custom_tools:
custom_tool_params += get_instruction_string(t) + "\n"
@@ -76,7 +106,6 @@ Reminder:
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
"""
return content
@@ -98,7 +127,6 @@ def get_parameters_string(custom_tool_definition) -> str:
)
# NOTE: Unused right now
def translate_custom_tool_definition_to_json(tool_def):
"""Translates ToolDefinition to json as expected by model
eg. output for a function
@@ -149,4 +177,4 @@ def translate_custom_tool_definition_to_json(tool_def):
else:
func_def["function"]["parameters"] = {}
return json.dumps(func_def)
return json.dumps(func_def, indent=4)
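
With indent=4, the serialized definition for the get_boiling_point tool from the client example would look roughly like the following. The exact contents of "parameters" depend on how the ToolParamDefinition fields are translated, so that part is elided here:

{
    "type": "function",
    "function": {
        "name": "get_boiling_point",
        "description": "Get the boiling point of an imaginary liquid (e.g. polyjuice)",
        "parameters": {...}
    }
}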

@@ -15,6 +15,7 @@ from llama_toolchain.agentic_system.api import (
AgenticSystemSessionCreateRequest,
AgenticSystemToolDefinition,
)
from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
from llama_toolchain.agentic_system.client import AgenticSystemClient
from llama_toolchain.agentic_system.tools.custom.execute import (
@@ -64,6 +65,7 @@ async def get_agent_system_instance(
custom_tools: Optional[List[Any]] = None,
disable_safety: bool = False,
model: str = "Meta-Llama3.1-8B-Instruct",
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> AgenticSystemClientWrapper:
custom_tools = custom_tools or []
@@ -113,6 +115,7 @@
]
),
sampling_params=SamplingParams(),
tool_prompt_format=tool_prompt_format,
),
)
create_response = await api.create_agentic_system(create_request)

@@ -8,12 +8,11 @@ from enum import Enum
from typing import Dict, Optional, Union
from llama_models.llama3_1.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
from pydantic import BaseModel, validator
from pydantic import BaseModel
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
@json_schema_type
@@ -43,6 +42,16 @@ class ShieldDefinition(BaseModel):
on_violation_action: OnViolationAction = OnViolationAction.RAISE
execution_config: Optional[RestAPIExecutionConfig] = None
@validator("shield_type", pre=True)
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinShield(v)
except ValueError:
return v
return v
@json_schema_type
class ShieldResponse(BaseModel):
@@ -51,3 +60,13 @@ class ShieldResponse(BaseModel):
is_violation: bool
violation_type: Optional[str] = None
violation_return_message: Optional[str] = None
@validator("shield_type", pre=True)
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinShield(v)
except ValueError:
return v
return v
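
The two validators give ShieldDefinition and ShieldResponse the same coercion behavior: a shield_type arriving as a plain string is upgraded to the BuiltinShield enum when it matches a known value, and passed through untouched otherwise. A sketch, with the enum member name assumed for illustration:

# assuming BuiltinShield has a member whose value is "llama_guard"
sd = ShieldDefinition(shield_type="llama_guard")
assert sd.shield_type == BuiltinShield.llama_guard  # coerced to the enum
sd = ShieldDefinition(shield_type="my_custom_shield")
assert sd.shield_type == "my_custom_shield"  # unknown values stay strings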