diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/datatypes.py
index 1dda64834..db4e40c4b 100644
--- a/llama_toolchain/agentic_system/api/datatypes.py
+++ b/llama_toolchain/agentic_system/api/datatypes.py
@@ -110,6 +110,35 @@ class Session(BaseModel):
     started_at: datetime
 
 
+@json_schema_type
+class ToolPromptFormat(Enum):
+    """This Enum refers to the prompt format for calling zero-shot tools.
+
+    `json` --
+        Refers to the json format for calling tools.
+        The json format takes a form like
+        {
+            "type": "function",
+            "function" : {
+                "name": "function_name",
+                "description": "function_description",
+                "parameters": {...}
+            }
+        }
+
+    `function_tag` --
+        This is an example of how you could define
+        your own user-defined format for making tool calls.
+        The function_tag format looks like this:
+        <function=function_name>(parameters)</function>
+
+    The detailed prompts for each of these formats are defined in `system_prompt.py`
+    """
+
+    json = "json"
+    function_tag = "function_tag"
+
+
 @json_schema_type
 class AgenticSystemInstanceConfig(BaseModel):
     instructions: str
@@ -127,6 +156,9 @@ class AgenticSystemInstanceConfig(BaseModel):
     # if you completely want to replace the messages prefixed by the system,
     # this is debug only
     debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(
+        default=ToolPromptFormat.json
+    )
 
 
 class AgenticSystemTurnResponseEventType(Enum):
diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py
index 71c578e2f..5b8053af9 100644
--- a/llama_toolchain/agentic_system/client.py
+++ b/llama_toolchain/agentic_system/client.py
@@ -13,8 +13,15 @@ import fire
 
 import httpx
 
-from llama_models.llama3_1.api.datatypes import BuiltinTool, SamplingParams
+from llama_models.llama3_1.api.datatypes import (
+    BuiltinTool,
+    SamplingParams,
+    ToolParamDefinition,
+    UserMessage,
+)
+from termcolor import cprint
 
+from llama_toolchain.agentic_system.event_logger import EventLogger
 from .api import (
     AgenticSystem,
     AgenticSystemCreateRequest,
@@ -25,6 +32,7 @@ from .api import (
     AgenticSystemToolDefinition,
     AgenticSystemTurnCreateRequest,
     AgenticSystemTurnResponseStreamChunk,
+    ToolPromptFormat,
 )
 
 
@@ -87,7 +95,7 @@ class AgenticSystemClient(AgenticSystem):
 
 async def run_main(host: str, port: int):
     # client to test remote impl of agentic system
-    api = await AgenticSystemClient(f"http://{host}:{port}")
+    api = AgenticSystemClient(f"http://{host}:{port}")
 
     tool_definitions = [
         AgenticSystemToolDefinition(
@@ -96,13 +104,28 @@ async def run_main(host: str, port: int):
         AgenticSystemToolDefinition(
             tool_name=BuiltinTool.wolfram_alpha,
         ),
-        AgenticSystemToolDefinition(
-            tool_name=BuiltinTool.photogen,
-        ),
         AgenticSystemToolDefinition(
             tool_name=BuiltinTool.code_interpreter,
         ),
     ]
+    tool_definitions += [
+        AgenticSystemToolDefinition(
+            tool_name="get_boiling_point",
+            description="Get the boiling point of an imaginary liquid (e.g. polyjuice)",
+            parameters={
+                "liquid_name": ToolParamDefinition(
+                    param_type="str",
+                    description="The name of the liquid",
+                    required=True,
+                ),
+                "celcius": ToolParamDefinition(
+                    param_type="str",
+                    description="Whether to return the boiling point in Celsius",
+                    required=False,
+                ),
+            },
+        ),
+    ]
 
     create_request = AgenticSystemCreateRequest(
         model="Meta-Llama3.1-8B-Instruct",
@@ -114,12 +137,44 @@ async def run_main(host: str, port: int):
             output_shields=[],
             quantization_config=None,
             debug_prefix_messages=[],
+            tool_prompt_format=ToolPromptFormat.json,
         ),
     )
 
     create_response = await api.create_agentic_system(create_request)
     print(create_response)
-    # TODO: Add chat session / turn apis to test e2e
+
+    session_response = await api.create_agentic_system_session(
+        AgenticSystemSessionCreateRequest(
+            system_id=create_response.system_id,
+            session_name="test_session",
+        )
+    )
+    print(session_response)
+
+    user_prompts = [
+        "Who are you?",
+        "What is the 100th prime number?",
+        "Search the web for who was the 44th President of the USA.",
+        "Write code to check if a number is prime. Use that to check if 7 is prime.",
+        "What is the boiling point of polyjuice potion?",
+    ]
+    for content in user_prompts:
+        cprint(f"User> {content}", color="blue")
+        iterator = api.create_agentic_system_turn(
+            AgenticSystemTurnCreateRequest(
+                system_id=create_response.system_id,
+                session_id=session_response.session_id,
+                messages=[
+                    UserMessage(content=content),
+                ],
+                stream=True,
+            )
+        )
+
+        async for event, log in EventLogger().log(iterator):
+            if log is not None:
+                log.print()
 
 
 def main(host: str, port: int):
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index 8e4555cb4..5be9f8bb6 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -10,6 +10,8 @@ import uuid
 from datetime import datetime
 from typing import AsyncGenerator, List, Optional
 
+from termcolor import cprint
+
 from llama_toolchain.agentic_system.api.datatypes import (
     AgenticSystemInstanceConfig,
     AgenticSystemTurnResponseEvent,
@@ -24,6 +26,7 @@ from llama_toolchain.agentic_system.api.datatypes import (
     ShieldCallStep,
     StepType,
     ToolExecutionStep,
+    ToolPromptFormat,
     Turn,
 )
 
@@ -51,7 +54,6 @@ from llama_toolchain.safety.api.datatypes import (
     ShieldDefinition,
     ShieldResponse,
 )
-from termcolor import cprint
 
 from llama_toolchain.agentic_system.api.endpoints import *  # noqa
 from .safety import SafetyException, ShieldRunnerMixin
@@ -74,6 +76,7 @@ class AgentInstance(ShieldRunnerMixin):
         output_shields: List[ShieldDefinition],
         max_infer_iters: int = 10,
         prefix_messages: Optional[List[Message]] = None,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
     ):
         self.system_id = system_id
         self.instance_config = instance_config
@@ -86,7 +89,9 @@ class AgentInstance(ShieldRunnerMixin):
             self.prefix_messages = prefix_messages
         else:
             self.prefix_messages = get_agentic_prefix_messages(
-                builtin_tools, custom_tool_definitions
+                builtin_tools,
+                custom_tool_definitions,
+                tool_prompt_format,
             )
 
         for m in self.prefix_messages:
diff --git a/llama_toolchain/agentic_system/meta_reference/agentic_system.py b/llama_toolchain/agentic_system/meta_reference/agentic_system.py
index 5db8d6168..ae1d282aa 100644
--- a/llama_toolchain/agentic_system/meta_reference/agentic_system.py
+++ b/llama_toolchain/agentic_system/meta_reference/agentic_system.py
@@ -108,6 +108,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
             input_shields=cfg.input_shields,
             output_shields=cfg.output_shields,
             prefix_messages=cfg.debug_prefix_messages,
+            tool_prompt_format=cfg.tool_prompt_format,
         )
 
         return AgenticSystemCreateResponse(
diff --git a/llama_toolchain/agentic_system/meta_reference/safety.py b/llama_toolchain/agentic_system/meta_reference/safety.py
index c78cb3028..ff3633f18 100644
--- a/llama_toolchain/agentic_system/meta_reference/safety.py
+++ b/llama_toolchain/agentic_system/meta_reference/safety.py
@@ -6,14 +6,15 @@
 
 from typing import List
 
-from llama_models.llama3_1.api.datatypes import Message, Role
+from llama_models.llama3_1.api.datatypes import Message, Role, UserMessage
+from termcolor import cprint
+
 from llama_toolchain.safety.api.datatypes import (
     OnViolationAction,
     ShieldDefinition,
     ShieldResponse,
 )
 from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety
-from termcolor import cprint
 
 
 class SafetyException(Exception):  # noqa: N818
@@ -36,12 +37,11 @@ class ShieldRunnerMixin:
     async def run_shields(
         self, messages: List[Message], shields: List[ShieldDefinition]
     ) -> List[ShieldResponse]:
+        messages = messages.copy()
         # some shields like llama-guard require the first message to be a user message
         # since this might be a tool call, first role might not be user
         if len(messages) > 0 and messages[0].role != Role.user.value:
-            # TODO(ashwin): we need to change the type of the message, this kind of modification
-            # is no longer appropriate
-            messages[0].role = Role.user.value
+            messages[0] = UserMessage(content=messages[0].content)
 
         res = await self.safety_api.run_shields(
             RunShieldRequest(
diff --git a/llama_toolchain/agentic_system/meta_reference/system_prompt.py b/llama_toolchain/agentic_system/meta_reference/system_prompt.py
index c8c616285..9db3218c1 100644
--- a/llama_toolchain/agentic_system/meta_reference/system_prompt.py
+++ b/llama_toolchain/agentic_system/meta_reference/system_prompt.py
@@ -5,21 +5,27 @@
 # the root directory of this source tree.
 
 import json
+import textwrap
 from datetime import datetime
 from typing import List
 
+from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
+
 from llama_toolchain.inference.api import (
     BuiltinTool,
     Message,
     SystemMessage,
     ToolDefinition,
+    UserMessage,
 )
 
 from .tools.builtin import SingleMessageBuiltinTool
 
 
 def get_agentic_prefix_messages(
-    builtin_tools: List[SingleMessageBuiltinTool], custom_tools: List[ToolDefinition]
+    builtin_tools: List[SingleMessageBuiltinTool],
+    custom_tools: List[ToolDefinition],
+    tool_prompt_format: ToolPromptFormat,
 ) -> List[Message]:
     messages = []
     content = ""
@@ -34,28 +40,52 @@ def get_agentic_prefix_messages(
         ]
     )
     if tool_str:
-        content += f"Tools: {tool_str}\n"
+        content += f"Tools: {tool_str}"
 
     current_date = datetime.now()
     formatted_date = current_date.strftime("%d %B %Y")
     date_str = f"""
 Cutting Knowledge Date: December 2023
-Today Date: {formatted_date}\n\n"""
+Today Date: {formatted_date}\n"""
     content += date_str
+    messages.append(SystemMessage(content=content))
 
     if custom_tools:
-        custom_message = get_system_prompt_for_custom_tools(custom_tools)
-        content += custom_message
+        if tool_prompt_format == ToolPromptFormat.function_tag:
+            text = prompt_for_function_tag(custom_tools)
+            messages.append(UserMessage(content=text))
+        elif tool_prompt_format == ToolPromptFormat.json:
+            text = prompt_for_json(custom_tools)
+            messages.append(UserMessage(content=text))
+        else:
+            raise NotImplementedError(
+                f"Tool prompt format {tool_prompt_format} is not supported"
+            )
+    else:
+        messages.append(SystemMessage(content=content))
 
-    # TODO: Replace this hard coded message with instructions coming in the request
-    if False:
-        content += "You are a helpful Assistant."
-
-    messages.append(SystemMessage(content=content))
     return messages
 
 
-def get_system_prompt_for_custom_tools(custom_tools: List[ToolDefinition]) -> str:
+def prompt_for_json(custom_tools: List[ToolDefinition]) -> str:
+    tool_defs = "\n".join(
+        translate_custom_tool_definition_to_json(t) for t in custom_tools
+    )
+    content = textwrap.dedent(
+        """
+        Answer the user's question by making use of the following functions if needed.
+        If none of the functions can be used, please say so.
+        Here is a list of functions in JSON format:
+        {tool_defs}
+
+        Return function calls in JSON format.
+        """
+    )
+    content = content.lstrip("\n").format(tool_defs=tool_defs)
+    return content
+
+
+def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str:
     custom_tool_params = ""
     for t in custom_tools:
         custom_tool_params += get_instruction_string(t) + "\n"
@@ -76,7 +106,6 @@ Reminder:
 - Required parameters MUST be specified
 - Only call one function at a time
 - Put the entire function call reply on one line
-
 """
     return content
 
@@ -98,7 +127,6 @@ def get_parameters_string(custom_tool_definition) -> str:
     )
 
 
-# NOTE: Unused right now
 def translate_custom_tool_definition_to_json(tool_def):
     """Translates ToolDefinition to json as expected by model
     eg. output for a function
@@ -149,4 +177,4 @@ def translate_custom_tool_definition_to_json(tool_def):
     else:
         func_def["function"]["parameters"] = {}
 
-    return json.dumps(func_def)
+    return json.dumps(func_def, indent=4)
diff --git a/llama_toolchain/agentic_system/utils.py b/llama_toolchain/agentic_system/utils.py
index bc1639b3d..3ae5c67b6 100644
--- a/llama_toolchain/agentic_system/utils.py
+++ b/llama_toolchain/agentic_system/utils.py
@@ -15,6 +15,7 @@ from llama_toolchain.agentic_system.api import (
     AgenticSystemSessionCreateRequest,
     AgenticSystemToolDefinition,
 )
+from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
 from llama_toolchain.agentic_system.client import AgenticSystemClient
 
 from llama_toolchain.agentic_system.tools.custom.execute import (
@@ -64,6 +65,7 @@ async def get_agent_system_instance(
     custom_tools: Optional[List[Any]] = None,
     disable_safety: bool = False,
     model: str = "Meta-Llama3.1-8B-Instruct",
+    tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
 ) -> AgenticSystemClientWrapper:
     custom_tools = custom_tools or []
 
@@ -113,6 +115,7 @@ async def get_agent_system_instance(
                 ]
             ),
             sampling_params=SamplingParams(),
+            tool_prompt_format=tool_prompt_format,
         ),
     )
     create_response = await api.create_agentic_system(create_request)
diff --git a/llama_toolchain/safety/api/datatypes.py b/llama_toolchain/safety/api/datatypes.py
index c5734da99..c0d23f589 100644
--- a/llama_toolchain/safety/api/datatypes.py
+++ b/llama_toolchain/safety/api/datatypes.py
@@ -8,12 +8,11 @@ from enum import Enum
 from typing import Dict, Optional, Union
 
 from llama_models.llama3_1.api.datatypes import ToolParamDefinition
-
 from llama_models.schema_utils import json_schema_type
 
-from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
+from pydantic import BaseModel, validator
 
-from pydantic import BaseModel
+from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
 
 
 @json_schema_type
@@ -43,6 +42,16 @@ class ShieldDefinition(BaseModel):
     on_violation_action: OnViolationAction = OnViolationAction.RAISE
     execution_config: Optional[RestAPIExecutionConfig] = None
 
+    @validator("shield_type", pre=True)
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinShield(v)
+            except ValueError:
+                return v
+        return v
+
 
 @json_schema_type
 class ShieldResponse(BaseModel):
@@ -51,3 +60,13 @@ class ShieldResponse(BaseModel):
     is_violation: bool
     violation_type: Optional[str] = None
     violation_return_message: Optional[str] = None
+
+    @validator("shield_type", pre=True)
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinShield(v)
+            except ValueError:
+                return v
+        return v