Add tool prompt formats

This commit is contained in:
Hardik Shah 2024-08-13 16:00:47 -07:00
parent 0c3e754453
commit 48b78430eb
5 changed files with 32 additions and 10 deletions

View file

@@ -110,6 +110,12 @@ class Session(BaseModel):
     started_at: datetime
@json_schema_type
class ToolPromptFormat(Enum):
    """Format used to render tool definitions into the model's system prompt.

    NOTE(review): in this change only ``function_tag`` is actually handled by
    ``get_agentic_prefix_messages``; selecting any other value (e.g. ``json``)
    raises ``NotImplementedError`` there. ``function_tag`` is also the default
    on ``AgenticSystemInstanceConfig.tool_prompt_format``.
    """

    json = "json"
    function_tag = "function_tag"
 @json_schema_type
 class AgenticSystemInstanceConfig(BaseModel):
     instructions: str
@@ -127,6 +133,9 @@ class AgenticSystemInstanceConfig(BaseModel):
     # if you completely want to replace the messages prefixed by the system,
     # this is debug only
     debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(
+        default=ToolPromptFormat.function_tag
+    )
 class AgenticSystemTurnResponseEventType(Enum):

View file

@@ -30,6 +30,7 @@ from .api import (
     AgenticSystemToolDefinition,
     AgenticSystemTurnCreateRequest,
     AgenticSystemTurnResponseStreamChunk,
+    ToolPromptFormat,
 )
@@ -132,6 +133,7 @@ async def run_main(host: str, port: int):
         output_shields=[],
         quantization_config=None,
         debug_prefix_messages=[],
+        tool_prompt_format=ToolPromptFormat.json,
     ),
 )

View file

@@ -10,6 +10,8 @@ import uuid
 from datetime import datetime
 from typing import AsyncGenerator, List, Optional
+from termcolor import cprint
 from llama_toolchain.agentic_system.api.datatypes import (
     AgenticSystemInstanceConfig,
     AgenticSystemTurnResponseEvent,
@@ -24,6 +26,7 @@ from llama_toolchain.agentic_system.api.datatypes import (
     ShieldCallStep,
     StepType,
     ToolExecutionStep,
+    ToolPromptFormat,
     Turn,
 )
@@ -51,7 +54,6 @@ from llama_toolchain.safety.api.datatypes import (
     ShieldDefinition,
     ShieldResponse,
 )
-from termcolor import cprint
 from llama_toolchain.agentic_system.api.endpoints import *  # noqa
 from .safety import SafetyException, ShieldRunnerMixin
@@ -74,6 +76,7 @@ class AgentInstance(ShieldRunnerMixin):
         output_shields: List[ShieldDefinition],
         max_infer_iters: int = 10,
         prefix_messages: Optional[List[Message]] = None,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.function_tag,
     ):
         self.system_id = system_id
         self.instance_config = instance_config
@@ -86,7 +89,9 @@ class AgentInstance(ShieldRunnerMixin):
             self.prefix_messages = prefix_messages
         else:
             self.prefix_messages = get_agentic_prefix_messages(
-                builtin_tools, custom_tool_definitions
+                builtin_tools,
+                custom_tool_definitions,
+                tool_prompt_format,
             )
         for m in self.prefix_messages:
View file

@@ -108,6 +108,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
             input_shields=cfg.input_shields,
             output_shields=cfg.output_shields,
             prefix_messages=cfg.debug_prefix_messages,
+            tool_prompt_format=cfg.tool_prompt_format,
         )
         return AgenticSystemCreateResponse(

View file

@@ -8,6 +8,8 @@ import json
 from datetime import datetime
 from typing import List
+from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
 from llama_toolchain.inference.api import (
     BuiltinTool,
     Message,
@@ -19,7 +21,9 @@ from .tools.builtin import SingleMessageBuiltinTool
 def get_agentic_prefix_messages(
-    builtin_tools: List[SingleMessageBuiltinTool], custom_tools: List[ToolDefinition]
+    builtin_tools: List[SingleMessageBuiltinTool],
+    custom_tools: List[ToolDefinition],
+    tool_prompt_format: ToolPromptFormat,
 ) -> List[Message]:
     messages = []
     content = ""
@@ -44,14 +48,15 @@ Today Date: {formatted_date}\n"""
     content += date_str
     if custom_tools:
-        custom_message = get_system_prompt_for_custom_tools(custom_tools)
-        content += custom_message
+        if tool_prompt_format == ToolPromptFormat.function_tag:
+            custom_message = get_system_prompt_for_custom_tools(custom_tools)
+            content += custom_message
+            messages.append(SystemMessage(content=content))
+        else:
+            raise NotImplementedError(
+                f"Tool prompt format {tool_prompt_format} is not supported"
+            )
-    # TODO: Replace this hard coded message with instructions coming in the request
-    if False:
-        content += "\nYou are a helpful Assistant."
-    messages.append(SystemMessage(content=content))
     return messages