diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/datatypes.py
index 1dda64834..825859b63 100644
--- a/llama_toolchain/agentic_system/api/datatypes.py
+++ b/llama_toolchain/agentic_system/api/datatypes.py
@@ -110,6 +110,12 @@ class Session(BaseModel):
     started_at: datetime
 
 
+@json_schema_type
+class ToolPromptFormat(Enum):
+    json = "json"
+    function_tag = "function_tag"
+
+
 @json_schema_type
 class AgenticSystemInstanceConfig(BaseModel):
     instructions: str
@@ -127,6 +133,9 @@ class AgenticSystemInstanceConfig(BaseModel):
     # if you completely want to replace the messages prefixed by the system,
     # this is debug only
     debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(
+        default=ToolPromptFormat.function_tag
+    )
 
 
 class AgenticSystemTurnResponseEventType(Enum):
diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py
index 597872e30..6b6dc6106 100644
--- a/llama_toolchain/agentic_system/client.py
+++ b/llama_toolchain/agentic_system/client.py
@@ -30,6 +30,7 @@ from .api import (
     AgenticSystemToolDefinition,
     AgenticSystemTurnCreateRequest,
     AgenticSystemTurnResponseStreamChunk,
+    ToolPromptFormat,
 )
 
 
@@ -132,6 +133,7 @@ async def run_main(host: str, port: int):
             output_shields=[],
             quantization_config=None,
             debug_prefix_messages=[],
+            tool_prompt_format=ToolPromptFormat.json,
         ),
     )
 
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index 8e4555cb4..69784504f 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -10,6 +10,8 @@ import uuid
 from datetime import datetime
 from typing import AsyncGenerator, List, Optional
 
+from termcolor import cprint
+
 from llama_toolchain.agentic_system.api.datatypes import (
     AgenticSystemInstanceConfig,
     AgenticSystemTurnResponseEvent,
@@ -24,6 +26,7 @@ from llama_toolchain.agentic_system.api.datatypes import (
     ShieldCallStep,
     StepType,
     ToolExecutionStep,
+    ToolPromptFormat,
     Turn,
 )
 
@@ -51,7 +54,6 @@ from llama_toolchain.safety.api.datatypes import (
     ShieldDefinition,
     ShieldResponse,
 )
-from termcolor import cprint
 
 from llama_toolchain.agentic_system.api.endpoints import *  # noqa
 from .safety import SafetyException, ShieldRunnerMixin
@@ -74,6 +76,7 @@ class AgentInstance(ShieldRunnerMixin):
         output_shields: List[ShieldDefinition],
         max_infer_iters: int = 10,
         prefix_messages: Optional[List[Message]] = None,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.function_tag,
     ):
         self.system_id = system_id
         self.instance_config = instance_config
@@ -86,7 +89,9 @@ class AgentInstance(ShieldRunnerMixin):
             self.prefix_messages = prefix_messages
         else:
             self.prefix_messages = get_agentic_prefix_messages(
-                builtin_tools, custom_tool_definitions
+                builtin_tools,
+                custom_tool_definitions,
+                tool_prompt_format,
             )
 
         for m in self.prefix_messages:
diff --git a/llama_toolchain/agentic_system/meta_reference/agentic_system.py b/llama_toolchain/agentic_system/meta_reference/agentic_system.py
index 5db8d6168..ae1d282aa 100644
--- a/llama_toolchain/agentic_system/meta_reference/agentic_system.py
+++ b/llama_toolchain/agentic_system/meta_reference/agentic_system.py
@@ -108,6 +108,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
             input_shields=cfg.input_shields,
             output_shields=cfg.output_shields,
             prefix_messages=cfg.debug_prefix_messages,
+            tool_prompt_format=cfg.tool_prompt_format,
         )
 
         return AgenticSystemCreateResponse(
diff --git a/llama_toolchain/agentic_system/meta_reference/system_prompt.py b/llama_toolchain/agentic_system/meta_reference/system_prompt.py
index d51e53a82..a6fa8e638 100644
--- a/llama_toolchain/agentic_system/meta_reference/system_prompt.py
+++ b/llama_toolchain/agentic_system/meta_reference/system_prompt.py
@@ -8,6 +8,8 @@ import json
 from datetime import datetime
 from typing import List
 
+from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
+
 from llama_toolchain.inference.api import (
     BuiltinTool,
     Message,
@@ -19,7 +21,9 @@ from .tools.builtin import SingleMessageBuiltinTool
 
 
 def get_agentic_prefix_messages(
-    builtin_tools: List[SingleMessageBuiltinTool], custom_tools: List[ToolDefinition]
+    builtin_tools: List[SingleMessageBuiltinTool],
+    custom_tools: List[ToolDefinition],
+    tool_prompt_format: ToolPromptFormat,
 ) -> List[Message]:
     messages = []
     content = ""
@@ -44,14 +48,15 @@ Today Date: {formatted_date}\n"""
     content += date_str
 
     if custom_tools:
-        custom_message = get_system_prompt_for_custom_tools(custom_tools)
-        content += custom_message
+        if tool_prompt_format == ToolPromptFormat.function_tag:
+            custom_message = get_system_prompt_for_custom_tools(custom_tools)
+            content += custom_message
+            messages.append(SystemMessage(content=content))
+        else:
+            raise NotImplementedError(
+                f"Tool prompt format {tool_prompt_format} is not supported"
+            )
 
-    # TODO: Replace this hard coded message with instructions coming in the request
-    if False:
-        content += "\nYou are a helpful Assistant."
-
-    messages.append(SystemMessage(content=content))
     return messages
 
 
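
Usage sketch (not part of the patch): with this change, a caller picks the tool-calling
prompt convention per agent instance through the new AgenticSystemInstanceConfig
tool_prompt_format field. All names below come from the diff above; the instruction
string is a placeholder, and the remaining config fields are elided on the assumption
that they are optional, as their Optional[...] defaults in datatypes.py suggest.

    from llama_toolchain.agentic_system.api.datatypes import (
        AgenticSystemInstanceConfig,
        ToolPromptFormat,
    )

    # Omitting tool_prompt_format keeps the default, ToolPromptFormat.function_tag,
    # the only format get_agentic_prefix_messages implements in this patch; selecting
    # ToolPromptFormat.json with custom tools raises NotImplementedError for now.
    config = AgenticSystemInstanceConfig(
        instructions="You are a helpful assistant.",  # placeholder
        tool_prompt_format=ToolPromptFormat.function_tag,
    )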