diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/datatypes.py
index 825859b63..e24a7e947 100644
--- a/llama_toolchain/agentic_system/api/datatypes.py
+++ b/llama_toolchain/agentic_system/api/datatypes.py
@@ -134,7 +134,7 @@ class AgenticSystemInstanceConfig(BaseModel):
     # this is debug only
     debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.function_tag
+        default=ToolPromptFormat.json
     )
 
diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py
index 6b6dc6106..3066255ea 100644
--- a/llama_toolchain/agentic_system/client.py
+++ b/llama_toolchain/agentic_system/client.py
@@ -19,7 +19,9 @@ from llama_models.llama3_1.api.datatypes import (
     ToolParamDefinition,
     UserMessage,
 )
+from termcolor import cprint
 
+from llama_toolchain.agentic_system.event_logger import EventLogger
 from .api import (
     AgenticSystem,
     AgenticSystemCreateRequest,
@@ -120,7 +122,18 @@ async def run_main(host: str, port: int):
                     required=True,
                 )
             },
-        )
+        ),
+        AgenticSystemToolDefinition(
+            tool_name="custom_tool_2",
+            description="a second custom tool",
+            parameters={
+                "param2": ToolParamDefinition(
+                    param_type="str",
+                    description="a string parameter",
+                    required=True,
+                )
+            },
+        ),
     ]
 
     create_request = AgenticSystemCreateRequest(
@@ -138,7 +151,7 @@ async def run_main(host: str, port: int):
     )
 
     create_response = await api.create_agentic_system(create_request)
-    print("Create Response -->", create_response)
+    print(create_response)
 
     session_response = await api.create_agentic_system_session(
         AgenticSystemSessionCreateRequest(
@@ -146,21 +159,28 @@ async def run_main(host: str, port: int):
             system_id=create_response.system_id,
             session_name="test_session",
         )
     )
-    print("Session Response -->", session_response)
+    print(session_response)
 
-    turn_response = api.create_agentic_system_turn(
-        AgenticSystemTurnCreateRequest(
-            system_id=create_response.system_id,
-            session_id=session_response.session_id,
-            messages=[
-                UserMessage(content="Who are you?"),
-            ],
-            stream=False,
+    user_prompts = [
+        "Who are you?",
+        "Write code to check if a number is prime. Use that to check if 7 is prime",
+    ]
+    for content in user_prompts:
+        cprint(f"User> {content}", color="blue")
+        iterator = api.create_agentic_system_turn(
+            AgenticSystemTurnCreateRequest(
+                system_id=create_response.system_id,
+                session_id=session_response.session_id,
+                messages=[
+                    UserMessage(content=content),
+                ],
+                stream=True,
+            )
         )
-    )
-    print("Turn Response -->")
-    async for chunk in turn_response:
-        print(chunk)
+
+        async for event, log in EventLogger().log(iterator):
+            if log is not None:
+                log.print()
 
 
 def main(host: str, port: int):
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index 69784504f..5be9f8bb6 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -76,7 +76,7 @@ class AgentInstance(ShieldRunnerMixin):
         output_shields: List[ShieldDefinition],
         max_infer_iters: int = 10,
         prefix_messages: Optional[List[Message]] = None,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.function_tag,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
     ):
         self.system_id = system_id
         self.instance_config = instance_config
diff --git a/llama_toolchain/agentic_system/meta_reference/safety.py b/llama_toolchain/agentic_system/meta_reference/safety.py
index c78cb3028..ff3633f18 100644
--- a/llama_toolchain/agentic_system/meta_reference/safety.py
+++ b/llama_toolchain/agentic_system/meta_reference/safety.py
@@ -6,14 +6,15 @@
 
 from typing import List
 
-from llama_models.llama3_1.api.datatypes import Message, Role
+from llama_models.llama3_1.api.datatypes import Message, Role, UserMessage
+from termcolor import cprint
+
 from llama_toolchain.safety.api.datatypes import (
     OnViolationAction,
     ShieldDefinition,
     ShieldResponse,
 )
 from llama_toolchain.safety.api.endpoints import RunShieldRequest, Safety
-from termcolor import cprint
 
 
 class SafetyException(Exception):  # noqa: N818
@@ -36,12 +37,11 @@ class ShieldRunnerMixin:
     async def run_shields(
         self, messages: List[Message], shields: List[ShieldDefinition]
     ) -> List[ShieldResponse]:
+        messages = messages.copy()
         # some shields like llama-guard require the first message to be a user message
         # since this might be a tool call, first role might not be user
         if len(messages) > 0 and messages[0].role != Role.user.value:
-            # TODO(ashwin): we need to change the type of the message, this kind of modification
-            # is no longer appropriate
-            messages[0].role = Role.user.value
+            messages[0] = UserMessage(content=messages[0].content)
 
         res = await self.safety_api.run_shields(
             RunShieldRequest(
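A note on the run_shields change above: the old code mutated messages[0].role in place, so the caller's conversation history was silently rewritten whenever the turn started with a tool call. The new code copies the list and re-wraps the first message instead. A minimal, self-contained sketch of the idea, using stand-in dataclasses rather than the real llama_models message types (ToolResponseMessage and view_for_shields here are hypothetical simplifications):

from dataclasses import dataclass
from typing import List


@dataclass
class ToolResponseMessage:
    content: str
    role: str = "ipython"


@dataclass
class UserMessage:
    content: str
    role: str = "user"


def view_for_shields(messages: List) -> List:
    # Shallow-copy the list so the caller's history stays untouched.
    messages = messages.copy()
    if len(messages) > 0 and messages[0].role != "user":
        # Re-wrap as a user message instead of mutating the original
        # message object's role field.
        messages[0] = UserMessage(content=messages[0].content)
    return messages


history = [ToolResponseMessage(content="[2, 3, 5, 7]")]
shield_view = view_for_shields(history)
assert shield_view[0].role == "user"
assert history[0].role == "ipython"  # original message is unchanged

Note that list.copy() is shallow: it is the re-wrap, not the copy, that protects the original message object.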
diff --git a/llama_toolchain/agentic_system/meta_reference/system_prompt.py b/llama_toolchain/agentic_system/meta_reference/system_prompt.py
index a6fa8e638..f5792d22b 100644
--- a/llama_toolchain/agentic_system/meta_reference/system_prompt.py
+++ b/llama_toolchain/agentic_system/meta_reference/system_prompt.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import json
+import textwrap
 from datetime import datetime
 from typing import List
@@ -15,6 +16,7 @@ from llama_toolchain.inference.api import (
     Message,
     SystemMessage,
     ToolDefinition,
+    UserMessage,
 )
 
 from .tools.builtin import SingleMessageBuiltinTool
@@ -49,18 +51,43 @@ Today Date: {formatted_date}\n"""
 
     if custom_tools:
         if tool_prompt_format == ToolPromptFormat.function_tag:
-            custom_message = get_system_prompt_for_custom_tools(custom_tools)
+            custom_message = prompt_for_function_tag(custom_tools)
             content += custom_message
             messages.append(SystemMessage(content=content))
+        elif tool_prompt_format == ToolPromptFormat.json:
+            messages.append(SystemMessage(content=content))
+            # json is added as a user prompt
+            text = prompt_for_json(custom_tools)
+            messages.append(UserMessage(content=text))
         else:
             raise NotImplementedError(
                 f"Tool prompt format {tool_prompt_format} is not supported"
             )
+    else:
+        messages.append(SystemMessage(content=content))
 
     return messages
 
 
-def get_system_prompt_for_custom_tools(custom_tools: List[ToolDefinition]) -> str:
+def prompt_for_json(custom_tools: List[ToolDefinition]) -> str:
+    tool_defs = "\n".join(
+        translate_custom_tool_definition_to_json(t) for t in custom_tools
+    )
+    content = textwrap.dedent(
+        """
+        Answer the user's question by making use of the following functions if needed.
+        If none of the functions can be used, please say so.
+        Here is a list of functions in JSON format:
+        {tool_defs}
+
+        Return function calls in JSON format.
+        """
+    )
+    content = content.lstrip("\n").format(tool_defs=tool_defs)
+    return content
+
+
+def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str:
     custom_tool_params = ""
     for t in custom_tools:
         custom_tool_params += get_instruction_string(t) + "\n"
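For context on the new prompt_for_json helper: textwrap.dedent strips the indentation that the triple-quoted template inherits from the function body, and lstrip("\n") drops the blank first line introduced by the opening quotes, so the rendered user message starts flush at column zero. A standalone sketch of the rendering, with a hypothetical get_weather entry standing in for the real output of translate_custom_tool_definition_to_json:

import json
import textwrap

# Hypothetical tool entry; in the toolchain this JSON is produced by
# translate_custom_tool_definition_to_json from a ToolDefinition.
tool_defs = json.dumps(
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "city": {"param_type": "str", "required": True},
            },
        },
    },
    indent=4,
)

template = textwrap.dedent(
    """
    Answer the user's question by making use of the following functions if needed.
    If none of the functions can be used, please say so.
    Here is a list of functions in JSON format:
    {tool_defs}

    Return function calls in JSON format.
    """
)
# dedent() removes the common leading whitespace; lstrip("\n") drops the
# leading newline created by the opening triple quote.
print(template.lstrip("\n").format(tool_defs=tool_defs))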
@@ -102,7 +129,6 @@ def get_parameters_string(custom_tool_definition) -> str:
     )
 
 
-# NOTE: Unused right now
 def translate_custom_tool_definition_to_json(tool_def):
     """Translates ToolDefinition to json as expected by model
     eg. output for a function
@@ -153,4 +179,4 @@ def translate_custom_tool_definition_to_json(tool_def):
     else:
         func_def["function"]["parameters"] = {}
 
-    return json.dumps(func_def)
+    return json.dumps(func_def, indent=4)
diff --git a/llama_toolchain/agentic_system/utils.py b/llama_toolchain/agentic_system/utils.py
index bc1639b3d..3ae5c67b6 100644
--- a/llama_toolchain/agentic_system/utils.py
+++ b/llama_toolchain/agentic_system/utils.py
@@ -15,6 +15,7 @@ from llama_toolchain.agentic_system.api import (
     AgenticSystemSessionCreateRequest,
     AgenticSystemToolDefinition,
 )
+from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
 from llama_toolchain.agentic_system.client import AgenticSystemClient
 
 from llama_toolchain.agentic_system.tools.custom.execute import (
@@ -64,6 +65,7 @@ async def get_agent_system_instance(
     custom_tools: Optional[List[Any]] = None,
     disable_safety: bool = False,
     model: str = "Meta-Llama3.1-8B-Instruct",
+    tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
 ) -> AgenticSystemClientWrapper:
     custom_tools = custom_tools or []
 
@@ -113,6 +115,7 @@ async def get_agent_system_instance(
                 ]
             ),
             sampling_params=SamplingParams(),
+            tool_prompt_format=tool_prompt_format,
         ),
     )
     create_response = await api.create_agentic_system(create_request)
diff --git a/llama_toolchain/safety/api/datatypes.py b/llama_toolchain/safety/api/datatypes.py
index c5734da99..c0d23f589 100644
--- a/llama_toolchain/safety/api/datatypes.py
+++ b/llama_toolchain/safety/api/datatypes.py
@@ -8,12 +8,11 @@ from enum import Enum
 from typing import Dict, Optional, Union
 
 from llama_models.llama3_1.api.datatypes import ToolParamDefinition
-
 from llama_models.schema_utils import json_schema_type
 
-from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
+from pydantic import BaseModel, validator
 
-from pydantic import BaseModel
+from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
 
 
 @json_schema_type
@@ -43,6 +42,16 @@ class ShieldDefinition(BaseModel):
     on_violation_action: OnViolationAction = OnViolationAction.RAISE
     execution_config: Optional[RestAPIExecutionConfig] = None
 
+    @validator("shield_type", pre=True)
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinShield(v)
+            except ValueError:
+                return v
+        return v
+
 
 @json_schema_type
 class ShieldResponse(BaseModel):
@@ -51,3 +60,13 @@ class ShieldResponse(BaseModel):
     is_violation: bool
     violation_type: Optional[str] = None
     violation_return_message: Optional[str] = None
+
+    @validator("shield_type", pre=True)
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinShield(v)
+            except ValueError:
+                return v
+        return v
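On the two identical validate_field hooks added above: shield_type is declared as Union[BuiltinShield, str], but over the wire it arrives as a plain string, so the pre=True validator coerces known builtin names back into the enum while letting custom shield names pass through unchanged. A minimal sketch of that behavior (pydantic v1 validator API; the enum members below are stand-ins, not the real BuiltinShield definition):

from enum import Enum
from typing import Union

from pydantic import BaseModel, validator


class BuiltinShield(Enum):
    # Stand-in members; the real enum lives in llama_toolchain.safety.api.datatypes.
    llama_guard = "llama_guard"
    code_scanner_guard = "code_scanner_guard"


class ShieldResponse(BaseModel):
    shield_type: Union[BuiltinShield, str]

    @validator("shield_type", pre=True)
    @classmethod
    def validate_field(cls, v):
        # Known strings become enum members; unknown (custom) shield
        # names stay as plain strings.
        if isinstance(v, str):
            try:
                return BuiltinShield(v)
            except ValueError:
                return v
        return v


assert ShieldResponse(shield_type="llama_guard").shield_type is BuiltinShield.llama_guard
assert ShieldResponse(shield_type="my_custom_shield").shield_type == "my_custom_shield"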