mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Add tool prompt formats
This commit is contained in:
parent
0c3e754453
commit
48b78430eb
5 changed files with 32 additions and 10 deletions
|
@ -110,6 +110,12 @@ class Session(BaseModel):
|
|||
started_at: datetime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolPromptFormat(Enum):
|
||||
json = "json"
|
||||
function_tag = "function_tag"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgenticSystemInstanceConfig(BaseModel):
|
||||
instructions: str
|
||||
|
@ -127,6 +133,9 @@ class AgenticSystemInstanceConfig(BaseModel):
|
|||
# if you completely want to replace the messages prefixed by the system,
|
||||
# this is debug only
|
||||
debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.function_tag
|
||||
)
|
||||
|
||||
|
||||
class AgenticSystemTurnResponseEventType(Enum):
|
||||
|
|
|
@ -30,6 +30,7 @@ from .api import (
|
|||
AgenticSystemToolDefinition,
|
||||
AgenticSystemTurnCreateRequest,
|
||||
AgenticSystemTurnResponseStreamChunk,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
|
||||
|
@ -132,6 +133,7 @@ async def run_main(host: str, port: int):
|
|||
output_shields=[],
|
||||
quantization_config=None,
|
||||
debug_prefix_messages=[],
|
||||
tool_prompt_format=ToolPromptFormat.json,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
@ -10,6 +10,8 @@ import uuid
|
|||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.agentic_system.api.datatypes import (
|
||||
AgenticSystemInstanceConfig,
|
||||
AgenticSystemTurnResponseEvent,
|
||||
|
@ -24,6 +26,7 @@ from llama_toolchain.agentic_system.api.datatypes import (
|
|||
ShieldCallStep,
|
||||
StepType,
|
||||
ToolExecutionStep,
|
||||
ToolPromptFormat,
|
||||
Turn,
|
||||
)
|
||||
|
||||
|
@ -51,7 +54,6 @@ from llama_toolchain.safety.api.datatypes import (
|
|||
ShieldDefinition,
|
||||
ShieldResponse,
|
||||
)
|
||||
from termcolor import cprint
|
||||
from llama_toolchain.agentic_system.api.endpoints import * # noqa
|
||||
|
||||
from .safety import SafetyException, ShieldRunnerMixin
|
||||
|
@ -74,6 +76,7 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
output_shields: List[ShieldDefinition],
|
||||
max_infer_iters: int = 10,
|
||||
prefix_messages: Optional[List[Message]] = None,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.function_tag,
|
||||
):
|
||||
self.system_id = system_id
|
||||
self.instance_config = instance_config
|
||||
|
@ -86,7 +89,9 @@ class AgentInstance(ShieldRunnerMixin):
|
|||
self.prefix_messages = prefix_messages
|
||||
else:
|
||||
self.prefix_messages = get_agentic_prefix_messages(
|
||||
builtin_tools, custom_tool_definitions
|
||||
builtin_tools,
|
||||
custom_tool_definitions,
|
||||
tool_prompt_format,
|
||||
)
|
||||
|
||||
for m in self.prefix_messages:
|
||||
|
|
|
@ -108,6 +108,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
|||
input_shields=cfg.input_shields,
|
||||
output_shields=cfg.output_shields,
|
||||
prefix_messages=cfg.debug_prefix_messages,
|
||||
tool_prompt_format=cfg.tool_prompt_format,
|
||||
)
|
||||
|
||||
return AgenticSystemCreateResponse(
|
||||
|
|
|
@ -8,6 +8,8 @@ import json
|
|||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
|
||||
|
||||
from llama_toolchain.inference.api import (
|
||||
BuiltinTool,
|
||||
Message,
|
||||
|
@ -19,7 +21,9 @@ from .tools.builtin import SingleMessageBuiltinTool
|
|||
|
||||
|
||||
def get_agentic_prefix_messages(
|
||||
builtin_tools: List[SingleMessageBuiltinTool], custom_tools: List[ToolDefinition]
|
||||
builtin_tools: List[SingleMessageBuiltinTool],
|
||||
custom_tools: List[ToolDefinition],
|
||||
tool_prompt_format: ToolPromptFormat,
|
||||
) -> List[Message]:
|
||||
messages = []
|
||||
content = ""
|
||||
|
@ -44,14 +48,15 @@ Today Date: {formatted_date}\n"""
|
|||
content += date_str
|
||||
|
||||
if custom_tools:
|
||||
custom_message = get_system_prompt_for_custom_tools(custom_tools)
|
||||
content += custom_message
|
||||
if tool_prompt_format == ToolPromptFormat.function_tag:
|
||||
custom_message = get_system_prompt_for_custom_tools(custom_tools)
|
||||
content += custom_message
|
||||
messages.append(SystemMessage(content=content))
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Tool prompt format {tool_prompt_format} is not supported"
|
||||
)
|
||||
|
||||
# TODO: Replace this hard coded message with instructions coming in the request
|
||||
if False:
|
||||
content += "\nYou are a helpful Assistant."
|
||||
|
||||
messages.append(SystemMessage(content=content))
|
||||
return messages
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue