From 00053b5bb07ab4b056420236c27dbc21d02ced59 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Thu, 15 Aug 2024 12:11:35 -0700 Subject: [PATCH] function_tag system prompt is also added as a user message --- llama_toolchain/agentic_system/client.py | 30 ++++++++----------- .../meta_reference/system_prompt.py | 8 ++--- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py index 3066255ea..5b8053af9 100644 --- a/llama_toolchain/agentic_system/client.py +++ b/llama_toolchain/agentic_system/client.py @@ -104,34 +104,25 @@ async def run_main(host: str, port: int): AgenticSystemToolDefinition( tool_name=BuiltinTool.wolfram_alpha, ), - AgenticSystemToolDefinition( - tool_name=BuiltinTool.photogen, - ), AgenticSystemToolDefinition( tool_name=BuiltinTool.code_interpreter, ), ] tool_definitions += [ AgenticSystemToolDefinition( - tool_name="custom_tool", - description="a custom tool", + tool_name="get_boiling_point", + description="Get the boiling point of a imaginary liquids (eg. polyjuice)", parameters={ - "param1": ToolParamDefinition( + "liquid_name": ToolParamDefinition( param_type="str", - description="a string parameter", + description="The name of the liquid", required=True, - ) - }, - ), - AgenticSystemToolDefinition( - tool_name="custom_tool_2", - description="a second custom tool", - parameters={ - "param2": ToolParamDefinition( + ), + "celcius": ToolParamDefinition( param_type="str", - description="a string parameter", - required=True, - ) + description="Whether to return the boiling point in Celcius", + required=False, + ), }, ), ] @@ -163,7 +154,10 @@ async def run_main(host: str, port: int): user_prompts = [ "Who are you?", + "what is the 100th prime number?", + "Search web for who was 44th President of USA?", "Write code to check if a number is prime. Use that to check if 7 is prime", + "What is the boiling point of polyjuicepotion ?", ] for content in user_prompts: cprint(f"User> {content}", color="blue") diff --git a/llama_toolchain/agentic_system/meta_reference/system_prompt.py b/llama_toolchain/agentic_system/meta_reference/system_prompt.py index d3b5b45d7..9db3218c1 100644 --- a/llama_toolchain/agentic_system/meta_reference/system_prompt.py +++ b/llama_toolchain/agentic_system/meta_reference/system_prompt.py @@ -48,15 +48,13 @@ def get_agentic_prefix_messages( Cutting Knowledge Date: December 2023 Today Date: {formatted_date}\n""" content += date_str + messages.append(SystemMessage(content=content)) if custom_tools: if tool_prompt_format == ToolPromptFormat.function_tag: - custom_message = prompt_for_function_tag(custom_tools) - content += custom_message - messages.append(SystemMessage(content=content)) + text = prompt_for_function_tag(custom_tools) + messages.append(UserMessage(content=text)) elif tool_prompt_format == ToolPromptFormat.json: - messages.append(SystemMessage(content=content)) - # json is added as a user prompt text = prompt_for_json(custom_tools) messages.append(UserMessage(content=text)) else: