Add tool prompt formats

This commit is contained in:
Hardik Shah 2024-08-13 16:00:47 -07:00
parent 0c3e754453
commit 48b78430eb
5 changed files with 32 additions and 10 deletions

View file

@@ -110,6 +110,12 @@ class Session(BaseModel):
started_at: datetime
@json_schema_type
class ToolPromptFormat(Enum):
    """Format used to render tool definitions into the model's prompt.

    NOTE(review): per get_agentic_prefix_messages in this commit, only
    ``function_tag`` is currently handled; selecting ``json`` raises
    NotImplementedError there.
    """

    json = "json"
    function_tag = "function_tag"
@json_schema_type
class AgenticSystemInstanceConfig(BaseModel):
instructions: str
@@ -127,6 +133,9 @@ class AgenticSystemInstanceConfig(BaseModel):
# if you completely want to replace the messages prefixed by the system,
# this is debug only
debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.function_tag
)
class AgenticSystemTurnResponseEventType(Enum):

View file

@@ -30,6 +30,7 @@ from .api import (
AgenticSystemToolDefinition,
AgenticSystemTurnCreateRequest,
AgenticSystemTurnResponseStreamChunk,
ToolPromptFormat,
)
@@ -132,6 +133,7 @@ async def run_main(host: str, port: int):
output_shields=[],
quantization_config=None,
debug_prefix_messages=[],
tool_prompt_format=ToolPromptFormat.json,
),
)

View file

@@ -10,6 +10,8 @@ import uuid
from datetime import datetime
from typing import AsyncGenerator, List, Optional
from termcolor import cprint
from llama_toolchain.agentic_system.api.datatypes import (
AgenticSystemInstanceConfig,
AgenticSystemTurnResponseEvent,
@@ -24,6 +26,7 @@ from llama_toolchain.agentic_system.api.datatypes import (
ShieldCallStep,
StepType,
ToolExecutionStep,
ToolPromptFormat,
Turn,
)
@@ -51,7 +54,6 @@ from llama_toolchain.safety.api.datatypes import (
ShieldDefinition,
ShieldResponse,
)
from termcolor import cprint
from llama_toolchain.agentic_system.api.endpoints import * # noqa
from .safety import SafetyException, ShieldRunnerMixin
@@ -74,6 +76,7 @@ class AgentInstance(ShieldRunnerMixin):
output_shields: List[ShieldDefinition],
max_infer_iters: int = 10,
prefix_messages: Optional[List[Message]] = None,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.function_tag,
):
self.system_id = system_id
self.instance_config = instance_config
@@ -86,7 +89,9 @@
self.prefix_messages = prefix_messages
else:
self.prefix_messages = get_agentic_prefix_messages(
builtin_tools, custom_tool_definitions
builtin_tools,
custom_tool_definitions,
tool_prompt_format,
)
for m in self.prefix_messages:

View file

@@ -108,6 +108,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
input_shields=cfg.input_shields,
output_shields=cfg.output_shields,
prefix_messages=cfg.debug_prefix_messages,
tool_prompt_format=cfg.tool_prompt_format,
)
return AgenticSystemCreateResponse(

View file

@@ -8,6 +8,8 @@ import json
from datetime import datetime
from typing import List
from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
from llama_toolchain.inference.api import (
BuiltinTool,
Message,
@@ -19,7 +21,9 @@ from .tools.builtin import SingleMessageBuiltinTool
def get_agentic_prefix_messages(
builtin_tools: List[SingleMessageBuiltinTool], custom_tools: List[ToolDefinition]
builtin_tools: List[SingleMessageBuiltinTool],
custom_tools: List[ToolDefinition],
tool_prompt_format: ToolPromptFormat,
) -> List[Message]:
messages = []
content = ""
@ -44,14 +48,15 @@ Today Date: {formatted_date}\n"""
content += date_str
if custom_tools:
if tool_prompt_format == ToolPromptFormat.function_tag:
custom_message = get_system_prompt_for_custom_tools(custom_tools)
content += custom_message
# TODO: Replace this hard coded message with instructions coming in the request
if False:
content += "\nYou are a helpful Assistant."
messages.append(SystemMessage(content=content))
else:
raise NotImplementedError(
f"Tool prompt format {tool_prompt_format} is not supported"
)
return messages