Updates to prompt for tool calls (#29)

* update system prompts to drop a trailing newline

* Add tool prompt formats

* support json format

* JSON in caps

* function_tag system prompt is also added as a user message

* added docstrings for ToolPromptFormat

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Author: Hardik Shah, 2024-08-15 13:23:51 -07:00 (committed by GitHub)
Parent: 0d933ac4c5
Commit: b8fc4d4dee
8 changed files with 173 additions and 30 deletions


@@ -110,6 +110,35 @@ class Session(BaseModel):
     started_at: datetime
 
 
+@json_schema_type
+class ToolPromptFormat(Enum):
+    """This Enum refers to the prompt format for calling zero-shot tools.
+
+    `json` --
+        Refers to the JSON format for calling tools.
+        The JSON format takes the form:
+        {
+            "type": "function",
+            "function" : {
+                "name": "function_name",
+                "description": "function_description",
+                "parameters": {...}
+            }
+        }
+
+    `function_tag` --
+        This is an example of how you could define
+        your own user-defined format for making tool calls.
+        The function_tag format looks like this:
+        <function=function_name>(parameters)</function>
+
+    The detailed prompts for each of these formats are defined in `system_prompt.py`.
+    """
+
+    json = "json"
+    function_tag = "function_tag"
+
+
 @json_schema_type
 class AgenticSystemInstanceConfig(BaseModel):
     instructions: str
@@ -127,6 +156,9 @@ class AgenticSystemInstanceConfig(BaseModel):
     # if you completely want to replace the messages prefixed by the system,
     # this is debug only
     debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(
+        default=ToolPromptFormat.json
+    )
 
 
 class AgenticSystemTurnResponseEventType(Enum):
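
The enum's two formats differ in how custom-tool instructions and calls are shaped. A rough sketch of the two shapes, mirrored from the docstring above with the commit's get_boiling_point example substituted for the placeholders (illustrative only, not commit content):

# Illustrative shapes only, taken from the ToolPromptFormat docstring;
# get_boiling_point is the custom tool the client test defines below.

# ToolPromptFormat.json -- tool definitions and calls travel as JSON:
json_shape = {
    "type": "function",
    "function": {
        "name": "get_boiling_point",
        "description": "Get the boiling point of an imaginary liquid",
        "parameters": {"liquid_name": {"param_type": "str", "required": True}},
    },
}

# ToolPromptFormat.function_tag -- calls travel as a tagged string; the
# "(parameters)" placeholder is kept verbatim from the docstring:
function_tag_shape = "<function=get_boiling_point>(parameters)</function>"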


@@ -13,8 +13,15 @@ import fire
 import httpx
 
-from llama_models.llama3_1.api.datatypes import BuiltinTool, SamplingParams
+from llama_models.llama3_1.api.datatypes import (
+    BuiltinTool,
+    SamplingParams,
+    ToolParamDefinition,
+    UserMessage,
+)
+from termcolor import cprint
+
+from llama_toolchain.agentic_system.event_logger import EventLogger
 
 from .api import (
     AgenticSystem,
     AgenticSystemCreateRequest,
@@ -25,6 +32,7 @@ from .api import (
     AgenticSystemToolDefinition,
     AgenticSystemTurnCreateRequest,
     AgenticSystemTurnResponseStreamChunk,
+    ToolPromptFormat,
 )
@@ -87,7 +95,7 @@ class AgenticSystemClient(AgenticSystem):
 
 async def run_main(host: str, port: int):
     # client to test remote impl of agentic system
-    api = await AgenticSystemClient(f"http://{host}:{port}")
+    api = AgenticSystemClient(f"http://{host}:{port}")
 
     tool_definitions = [
         AgenticSystemToolDefinition(
@@ -96,13 +104,28 @@ async def run_main(host: str, port: int):
         AgenticSystemToolDefinition(
             tool_name=BuiltinTool.wolfram_alpha,
         ),
+        AgenticSystemToolDefinition(
+            tool_name=BuiltinTool.photogen,
+        ),
         AgenticSystemToolDefinition(
             tool_name=BuiltinTool.code_interpreter,
         ),
     ]
+    tool_definitions += [
+        AgenticSystemToolDefinition(
+            tool_name="get_boiling_point",
+            description="Get the boiling point of an imaginary liquid (eg. polyjuice)",
+            parameters={
+                "liquid_name": ToolParamDefinition(
+                    param_type="str",
+                    description="The name of the liquid",
+                    required=True,
+                ),
+                "celcius": ToolParamDefinition(
+                    param_type="str",
+                    description="Whether to return the boiling point in Celsius",
+                    required=False,
+                ),
+            },
+        ),
+    ]
 
     create_request = AgenticSystemCreateRequest(
         model="Meta-Llama3.1-8B-Instruct",
@@ -114,12 +137,44 @@ async def run_main(host: str, port: int):
             output_shields=[],
             quantization_config=None,
             debug_prefix_messages=[],
+            tool_prompt_format=ToolPromptFormat.json,
         ),
     )
 
     create_response = await api.create_agentic_system(create_request)
     print(create_response)
 
+    # TODO: Add chat session / turn apis to test e2e
+    session_response = await api.create_agentic_system_session(
+        AgenticSystemSessionCreateRequest(
+            system_id=create_response.system_id,
+            session_name="test_session",
+        )
+    )
+    print(session_response)
+
+    user_prompts = [
+        "Who are you?",
+        "what is the 100th prime number?",
+        "Search web for who was 44th President of USA?",
+        "Write code to check if a number is prime. Use that to check if 7 is prime",
+        "What is the boiling point of polyjuicepotion ?",
+    ]
+    for content in user_prompts:
+        cprint(f"User> {content}", color="blue")
+        iterator = api.create_agentic_system_turn(
+            AgenticSystemTurnCreateRequest(
+                system_id=create_response.system_id,
+                session_id=session_response.session_id,
+                messages=[
+                    UserMessage(content=content),
+                ],
+                stream=True,
+            )
+        )
+        async for event, log in EventLogger().log(iterator):
+            if log is not None:
+                log.print()
 
 
 def main(host: str, port: int):
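
For completeness, a sketch of how this test client is typically launched; `def main(...)` is the context line above, but its body and the `fire.Fire` hookup are assumptions based on the `import fire` at the top of the file, not part of the diff:

import asyncio

import fire


async def run_main(host: str, port: int):
    # Stub standing in for the run_main coroutine shown in the diff above.
    print(f"would connect to http://{host}:{port}")


def main(host: str, port: int):
    # Bridge the async test flow into a synchronous CLI entry point.
    asyncio.run(run_main(host, port))


if __name__ == "__main__":
    fire.Fire(main)  # e.g. python client.py localhost 5000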


@@ -10,6 +10,8 @@ import uuid
 from datetime import datetime
 from typing import AsyncGenerator, List, Optional
 
+from termcolor import cprint
+
 from llama_toolchain.agentic_system.api.datatypes import (
     AgenticSystemInstanceConfig,
     AgenticSystemTurnResponseEvent,
@@ -24,6 +26,7 @@ from llama_toolchain.agentic_system.api.datatypes import (
     ShieldCallStep,
     StepType,
     ToolExecutionStep,
+    ToolPromptFormat,
     Turn,
 )
@@ -51,7 +54,6 @@ from llama_toolchain.safety.api.datatypes import (
     ShieldDefinition,
     ShieldResponse,
 )
-from termcolor import cprint
 
 from llama_toolchain.agentic_system.api.endpoints import *  # noqa
 
 from .safety import SafetyException, ShieldRunnerMixin
@@ -74,6 +76,7 @@ class AgentInstance(ShieldRunnerMixin):
         output_shields: List[ShieldDefinition],
         max_infer_iters: int = 10,
         prefix_messages: Optional[List[Message]] = None,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
     ):
         self.system_id = system_id
         self.instance_config = instance_config
@@ -86,7 +89,9 @@ class AgentInstance(ShieldRunnerMixin):
             self.prefix_messages = prefix_messages
         else:
             self.prefix_messages = get_agentic_prefix_messages(
-                builtin_tools, custom_tool_definitions
+                builtin_tools,
+                custom_tool_definitions,
+                tool_prompt_format,
             )
 
         for m in self.prefix_messages:


@@ -108,6 +108,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
             input_shields=cfg.input_shields,
             output_shields=cfg.output_shields,
             prefix_messages=cfg.debug_prefix_messages,
+            tool_prompt_format=cfg.tool_prompt_format,
         )
 
         return AgenticSystemCreateResponse(


@@ -6,14 +6,15 @@
 
 from typing import List
 
-from llama_models.llama3_1.api.datatypes import Message, Role
+from llama_models.llama3_1.api.datatypes import Message, Role, UserMessage
+from termcolor import cprint
+
 from llama_toolchain.safety.api.datatypes import (
     OnViolationAction,
     ShieldDefinition,
     ShieldResponse,
 )
 from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety
-from termcolor import cprint
 
 
 class SafetyException(Exception):  # noqa: N818
@@ -36,12 +37,11 @@ class ShieldRunnerMixin:
     async def run_shields(
         self, messages: List[Message], shields: List[ShieldDefinition]
     ) -> List[ShieldResponse]:
+        messages = messages.copy()
         # some shields like llama-guard require the first message to be a user message
         # since this might be a tool call, first role might not be user
        if len(messages) > 0 and messages[0].role != Role.user.value:
-            # TODO(ashwin): we need to change the type of the message, this kind of modification
-            # is no longer appropriate
-            messages[0].role = Role.user.value
+            messages[0] = UserMessage(content=messages[0].content)
 
         res = await self.safety_api.run_shields(
             RunShieldRequest(
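
The change above swaps in-place role mutation for a list copy plus a freshly constructed UserMessage. A minimal standalone sketch of why this matters (the dataclasses below stand in for the toolchain's Message types; they are not the real ones):

# Illustrative sketch of the copy-then-replace pattern above.
from dataclasses import dataclass


@dataclass
class ToolMessage:
    role: str
    content: str


@dataclass
class UserMessage:
    content: str
    role: str = "user"


def as_shield_input(messages):
    # Copy so the caller's conversation history is never mutated.
    messages = messages.copy()
    if messages and messages[0].role != "user":
        # Build a new correctly-typed message instead of flipping .role
        # on an object whose type still says "tool".
        messages[0] = UserMessage(content=messages[0].content)
    return messages


history = [ToolMessage(role="tool", content="result: 42")]
shield_input = as_shield_input(history)
assert history[0].role == "tool"       # original history untouched
assert shield_input[0].role == "user"  # shield sees a user message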


@@ -5,21 +5,27 @@
 # the root directory of this source tree.
 
 import json
+import textwrap
+
 from datetime import datetime
 from typing import List
 
+from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
 from llama_toolchain.inference.api import (
     BuiltinTool,
     Message,
     SystemMessage,
     ToolDefinition,
+    UserMessage,
 )
 
 from .tools.builtin import SingleMessageBuiltinTool
 
 
 def get_agentic_prefix_messages(
-    builtin_tools: List[SingleMessageBuiltinTool], custom_tools: List[ToolDefinition]
+    builtin_tools: List[SingleMessageBuiltinTool],
+    custom_tools: List[ToolDefinition],
+    tool_prompt_format: ToolPromptFormat,
 ) -> List[Message]:
     messages = []
     content = ""
@@ -34,28 +40,52 @@ def get_agentic_prefix_messages(
         ]
     )
     if tool_str:
-        content += f"Tools: {tool_str}\n"
+        content += f"Tools: {tool_str}"
 
     current_date = datetime.now()
     formatted_date = current_date.strftime("%d %B %Y")
     date_str = f"""
 Cutting Knowledge Date: December 2023
-Today Date: {formatted_date}\n\n"""
+Today Date: {formatted_date}\n"""
     content += date_str
+    messages.append(SystemMessage(content=content))
 
     if custom_tools:
-        custom_message = get_system_prompt_for_custom_tools(custom_tools)
-        content += custom_message
+        if tool_prompt_format == ToolPromptFormat.function_tag:
+            text = prompt_for_function_tag(custom_tools)
+            messages.append(UserMessage(content=text))
+        elif tool_prompt_format == ToolPromptFormat.json:
+            text = prompt_for_json(custom_tools)
+            messages.append(UserMessage(content=text))
+        else:
+            raise NotImplementedError(
+                f"Tool prompt format {tool_prompt_format} is not supported"
+            )
+    else:
+        messages.append(SystemMessage(content=content))
 
-    # TODO: Replace this hard coded message with instructions coming in the request
-    if False:
-        content += "You are a helpful Assistant."
-
-    messages.append(SystemMessage(content=content))
     return messages
 
 
-def get_system_prompt_for_custom_tools(custom_tools: List[ToolDefinition]) -> str:
+def prompt_for_json(custom_tools: List[ToolDefinition]) -> str:
+    tool_defs = "\n".join(
+        translate_custom_tool_definition_to_json(t) for t in custom_tools
+    )
+    content = textwrap.dedent(
+        """
+        Answer the user's question by making use of the following functions if needed.
+        If none of the functions can be used, please say so.
+        Here is a list of functions in JSON format:
+        {tool_defs}
+        Return function calls in JSON format.
+        """
+    )
+    content = content.lstrip("\n").format(tool_defs=tool_defs)
+    return content
+
+
+def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str:
     custom_tool_params = ""
     for t in custom_tools:
         custom_tool_params += get_instruction_string(t) + "\n"
@@ -76,7 +106,6 @@ Reminder:
 - Required parameters MUST be specified
 - Only call one function at a time
 - Put the entire function call reply on one line
-
 """
     return content
@@ -98,7 +127,6 @@ def get_parameters_string(custom_tool_definition) -> str:
     )
 
-# NOTE: Unused right now
 def translate_custom_tool_definition_to_json(tool_def):
     """Translates ToolDefinition to json as expected by model
     eg. output for a function
@@ -149,4 +177,4 @@ def translate_custom_tool_definition_to_json(tool_def):
     else:
         func_def["function"]["parameters"] = {}
 
-    return json.dumps(func_def)
+    return json.dumps(func_def, indent=4)
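
To see what the model now receives under ToolPromptFormat.json, here is a standalone sketch that renders the same instruction block for one tool. The outer JSON shape follows the ToolPromptFormat docstring; the inner parameters schema is an assumption about translate_custom_tool_definition_to_json's output, not taken from the commit:

import json
import textwrap

# Hypothetical translated tool definition; outer shape per the docstring,
# inner parameters schema assumed for illustration.
tool_def = {
    "type": "function",
    "function": {
        "name": "get_boiling_point",
        "description": "Get the boiling point of an imaginary liquid (eg. polyjuice)",
        "parameters": {"liquid_name": {"param_type": "str", "required": True}},
    },
}

prompt = textwrap.dedent(
    """
    Answer the user's question by making use of the following functions if needed.
    If none of the functions can be used, please say so.
    Here is a list of functions in JSON format:
    {tool_defs}
    Return function calls in JSON format.
    """
).lstrip("\n").format(tool_defs=json.dumps(tool_def, indent=4))

print(prompt)  # this text is appended to the prefix as a user message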


@@ -15,6 +15,7 @@ from llama_toolchain.agentic_system.api (
     AgenticSystemSessionCreateRequest,
     AgenticSystemToolDefinition,
 )
+from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
 from llama_toolchain.agentic_system.client import AgenticSystemClient
 
 from llama_toolchain.agentic_system.tools.custom.execute import (
@@ -64,6 +65,7 @@ async def get_agent_system_instance(
     custom_tools: Optional[List[Any]] = None,
     disable_safety: bool = False,
     model: str = "Meta-Llama3.1-8B-Instruct",
+    tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
 ) -> AgenticSystemClientWrapper:
     custom_tools = custom_tools or []
@@ -113,6 +115,7 @@ async def get_agent_system_instance(
                 ]
             ),
             sampling_params=SamplingParams(),
+            tool_prompt_format=tool_prompt_format,
         ),
     )
     create_response = await api.create_agentic_system(create_request)


@@ -8,12 +8,11 @@ from enum import Enum
 from typing import Dict, Optional, Union
 
 from llama_models.llama3_1.api.datatypes import ToolParamDefinition
-
 from llama_models.schema_utils import json_schema_type
-from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
-from pydantic import BaseModel
+from pydantic import BaseModel, validator
+
+from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
 
 
 @json_schema_type
@@ -43,6 +42,16 @@ class ShieldDefinition(BaseModel):
     on_violation_action: OnViolationAction = OnViolationAction.RAISE
     execution_config: Optional[RestAPIExecutionConfig] = None
 
+    @validator("shield_type", pre=True)
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinShield(v)
+            except ValueError:
+                return v
+        return v
+
 
 @json_schema_type
 class ShieldResponse(BaseModel):
@@ -51,3 +60,13 @@ class ShieldResponse(BaseModel):
     is_violation: bool
     violation_type: Optional[str] = None
     violation_return_message: Optional[str] = None
+
+    @validator("shield_type", pre=True)
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinShield(v)
+            except ValueError:
+                return v
+        return v
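
The added validators coerce enum value strings back into BuiltinShield members while letting unknown custom shield names pass through as plain strings. A standalone sketch of that behavior (ToyShield stands in for BuiltinShield, whose members are not shown in this diff):

from enum import Enum
from typing import Union

from pydantic import BaseModel, validator


class ToyShield(Enum):
    llama_guard = "llama_guard"


class ShieldRef(BaseModel):
    shield_type: Union[ToyShield, str]

    @validator("shield_type", pre=True)
    @classmethod
    def validate_field(cls, v):
        # Known enum value strings become enum members; unknown strings
        # pass through so custom shield names still validate.
        if isinstance(v, str):
            try:
                return ToyShield(v)
            except ValueError:
                return v
        return v


assert ShieldRef(shield_type="llama_guard").shield_type is ToyShield.llama_guard
assert ShieldRef(shield_type="my_custom_shield").shield_type == "my_custom_shield"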