diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/datatypes.py index 648aed698..689abeceb 100644 --- a/llama_toolchain/agentic_system/api/datatypes.py +++ b/llama_toolchain/agentic_system/api/datatypes.py @@ -110,35 +110,6 @@ class Session(BaseModel): started_at: datetime -@json_schema_type -class ToolPromptFormat(Enum): - """This Enum refers to the prompt format for calling zero shot tools - - `json` -- - Refers to the json format for calling tools. - The json format takes the form like - { - "type": "function", - "function" : { - "name": "function_name", - "description": "function_description", - "parameters": {...} - } - } - - `function_tag` -- - This is an example of how you could define - your own user defined format for making tool calls. - The function_tag format looks like this, - (parameters) - - The detailed prompts for each of these formats are defined in `system_prompt.py` - """ - - json = "json" - function_tag = "function_tag" - - @json_schema_type class AgenticSystemInstanceConfig(BaseModel): instructions: str diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py index 5be9f8bb6..5de17d7b9 100644 --- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py +++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py @@ -56,10 +56,10 @@ from llama_toolchain.safety.api.datatypes import ( ) from llama_toolchain.agentic_system.api.endpoints import * # noqa +from llama_toolchain.tools.base import BaseTool +from llama_toolchain.tools.builtin import SingleMessageBuiltinTool + from .safety import SafetyException, ShieldRunnerMixin -from .system_prompt import get_agentic_prefix_messages -from .tools.base import BaseTool -from .tools.builtin import SingleMessageBuiltinTool class AgentInstance(ShieldRunnerMixin): @@ -85,18 +85,6 @@ class AgentInstance(ShieldRunnerMixin): self.inference_api = inference_api self.safety_api = safety_api - if prefix_messages is not None and len(prefix_messages) > 0: - self.prefix_messages = prefix_messages - else: - self.prefix_messages = get_agentic_prefix_messages( - builtin_tools, - custom_tool_definitions, - tool_prompt_format, - ) - - for m in self.prefix_messages: - print(m.content) - self.max_infer_iters = max_infer_iters self.tools_dict = {t.get_name(): t for t in builtin_tools} @@ -344,7 +332,7 @@ class AgentInstance(ShieldRunnerMixin): stream: bool = False, max_gen_len: Optional[int] = None, ) -> AsyncGenerator: - input_messages = preprocess_dialog(input_messages, self.prefix_messages) + input_messages = preprocess_dialog(input_messages) attachments = [] @@ -373,7 +361,8 @@ class AgentInstance(ShieldRunnerMixin): req = ChatCompletionRequest( model=self.model, messages=input_messages, - available_tools=self.instance_config.available_tools, + tools=self.instance_config.available_tools, + tool_prompt_format=self.instance_config.tool_prompt_format, stream=True, sampling_params=SamplingParams( temperature=temperature, @@ -601,14 +590,12 @@ def attachment_message(url: URL) -> ToolResponseMessage: ) -def preprocess_dialog( - messages: List[Message], prefix_messages: List[Message] -) -> List[Message]: +def preprocess_dialog(messages: List[Message]) -> List[Message]: """ Preprocesses the dialog by removing the system message and adding the system message to the beginning of the dialog. 
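With the prefix-message construction removed, tool handling moves behind the inference API: the agent now forwards the tool definitions and the desired prompt format on the request instead of rendering its own system prompt. A minimal sketch of the new calling convention (the model id and the brave_search tool are illustrative choices, not taken from this hunk):

```python
from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition, UserMessage

from llama_toolchain.inference.api.datatypes import ToolPromptFormat
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest

# Instead of prepending a hand-built system prompt, the caller attaches the
# tool definitions and a prompt format; the inference layer renders the prompt.
request = ChatCompletionRequest(
    model="Meta-Llama3.1-8B-Instruct",         # illustrative model id
    messages=[UserMessage(content="Who is the current US President?")],
    tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
    tool_prompt_format=ToolPromptFormat.json,  # json is the default
    stream=True,
)
```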
""" - ret = prefix_messages.copy() + ret = [] for m in messages: if m.role == Role.system.value: diff --git a/llama_toolchain/agentic_system/meta_reference/agentic_system.py b/llama_toolchain/agentic_system/meta_reference/agentic_system.py index 5252e7515..0d3f33507 100644 --- a/llama_toolchain/agentic_system/meta_reference/agentic_system.py +++ b/llama_toolchain/agentic_system/meta_reference/agentic_system.py @@ -24,17 +24,17 @@ from llama_toolchain.agentic_system.api import ( AgenticSystemTurnCreateRequest, ) -from .agent_instance import AgentInstance - -from .config import AgenticSystemConfig - -from .tools.builtin import ( +from llama_toolchain.tools.builtin import ( BraveSearchTool, CodeInterpreterTool, PhotogenTool, WolframAlphaTool, ) -from .tools.safety import with_safety +from llama_toolchain.tools.safety import with_safety + +from .agent_instance import AgentInstance + +from .config import AgenticSystemConfig logger = logging.getLogger() diff --git a/llama_toolchain/agentic_system/tools/custom/execute.py b/llama_toolchain/agentic_system/meta_reference/execute_with_custom_tools.py similarity index 100% rename from llama_toolchain/agentic_system/tools/custom/execute.py rename to llama_toolchain/agentic_system/meta_reference/execute_with_custom_tools.py diff --git a/llama_toolchain/agentic_system/utils.py b/llama_toolchain/agentic_system/utils.py index 9613b45df..b2ba4fec8 100644 --- a/llama_toolchain/agentic_system/utils.py +++ b/llama_toolchain/agentic_system/utils.py @@ -18,7 +18,7 @@ from llama_toolchain.agentic_system.api import ( from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat from llama_toolchain.agentic_system.client import AgenticSystemClient -from llama_toolchain.agentic_system.tools.custom.execute import ( +from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import ( execute_with_custom_tools, ) from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition diff --git a/llama_toolchain/inference/api/datatypes.py b/llama_toolchain/inference/api/datatypes.py index 571ecc3ea..cad8f4377 100644 --- a/llama_toolchain/inference/api/datatypes.py +++ b/llama_toolchain/inference/api/datatypes.py @@ -15,6 +15,41 @@ from typing_extensions import Annotated from llama_models.llama3.api.datatypes import * # noqa: F403 +@json_schema_type +class ToolChoice(Enum): + auto = "auto" + required = "required" + + +@json_schema_type +class ToolPromptFormat(Enum): + """This Enum refers to the prompt format for calling zero shot tools + + `json` -- + Refers to the json format for calling tools. + The json format takes the form like + { + "type": "function", + "function" : { + "name": "function_name", + "description": "function_description", + "parameters": {...} + } + } + + `function_tag` -- + This is an example of how you could define + your own user defined format for making tool calls. 
+ The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are defined in `system_prompt.py` + """ + + json = "json" + function_tag = "function_tag" + + class LogProbConfig(BaseModel): top_k: Optional[int] = 0 diff --git a/llama_toolchain/inference/api/endpoints.py b/llama_toolchain/inference/api/endpoints.py index ef1c7b159..26773e439 100644 --- a/llama_toolchain/inference/api/endpoints.py +++ b/llama_toolchain/inference/api/endpoints.py @@ -7,6 +7,8 @@ from .datatypes import * # noqa: F403 from typing import Optional, Protocol +from llama_models.llama3.api.datatypes import ToolDefinition + # this dependency is annoying and we need a forked up version anyway from llama_models.schema_utils import webmethod @@ -56,7 +58,11 @@ class ChatCompletionRequest(BaseModel): sampling_params: Optional[SamplingParams] = SamplingParams() # zero-shot tool definitions as input to the model - available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list) + tools: Optional[List[ToolDefinition]] = Field(default_factory=list) + tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) + tool_prompt_format: Optional[ToolPromptFormat] = Field( + default=ToolPromptFormat.json + ) stream: Optional[bool] = False logprobs: Optional[LogProbConfig] = None @@ -82,8 +88,11 @@ class BatchChatCompletionRequest(BaseModel): sampling_params: Optional[SamplingParams] = SamplingParams() # zero-shot tool definitions as input to the model - available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list) - + tools: Optional[List[ToolDefinition]] = Field(default_factory=list) + tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) + tool_prompt_format: Optional[ToolPromptFormat] = Field( + default=ToolPromptFormat.json + ) logprobs: Optional[LogProbConfig] = None diff --git a/llama_toolchain/inference/meta_reference/inference.py b/llama_toolchain/inference/meta_reference/inference.py index 84caf1ecf..dc674a25b 100644 --- a/llama_toolchain/inference/meta_reference/inference.py +++ b/llama_toolchain/inference/meta_reference/inference.py @@ -22,7 +22,7 @@ from llama_toolchain.inference.api import ( ToolCallDelta, ToolCallParseStatus, ) - +from llama_toolchain.inference.prepare_messages import prepare_messages_for_tools from .config import MetaReferenceImplConfig from .model_parallel import LlamaModelParallelGenerator @@ -67,6 +67,7 @@ class MetaReferenceInferenceImpl(Inference): ) -> AsyncIterator[ Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse] ]: + request = prepare_messages_for_tools(request) model = resolve_model(request.model) if model is None: raise RuntimeError( diff --git a/llama_toolchain/inference/ollama/ollama.py b/llama_toolchain/inference/ollama/ollama.py index 8901d5c02..8bfd38a71 100644 --- a/llama_toolchain/inference/ollama/ollama.py +++ b/llama_toolchain/inference/ollama/ollama.py @@ -32,7 +32,7 @@ from llama_toolchain.inference.api import ( ToolCallDelta, ToolCallParseStatus, ) - +from llama_toolchain.inference.prepare_messages import prepare_messages_for_tools from .config import OllamaImplConfig # TODO: Eventually this will move to the llama cli model list command @@ -111,6 +111,7 @@ class OllamaInference(Inference): return options async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: + request = prepare_messages_for_tools(request) # accumulate sampling params and other options to pass to ollama options = self.get_ollama_chat_options(request) ollama_model = 
self.resolve_ollama_model(request.model) diff --git a/llama_toolchain/agentic_system/meta_reference/system_prompt.py b/llama_toolchain/inference/prepare_messages.py similarity index 59% rename from llama_toolchain/agentic_system/meta_reference/system_prompt.py rename to llama_toolchain/inference/prepare_messages.py index 9db3218c1..e23bbbe8f 100644 --- a/llama_toolchain/agentic_system/meta_reference/system_prompt.py +++ b/llama_toolchain/inference/prepare_messages.py @@ -1,70 +1,90 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - import json +import os import textwrap + from datetime import datetime -from typing import List - -from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat - -from llama_toolchain.inference.api import ( - BuiltinTool, - Message, - SystemMessage, - ToolDefinition, - UserMessage, +from llama_toolchain.inference.api import * # noqa: F403 +from llama_toolchain.tools.builtin import ( + BraveSearchTool, + CodeInterpreterTool, + PhotogenTool, + WolframAlphaTool, ) -from .tools.builtin import SingleMessageBuiltinTool + +def tool_breakdown(tools: List[ToolDefinition]) -> str: + builtin_tools, custom_tools = [], [] + for dfn in tools: + if isinstance(dfn.tool_name, BuiltinTool): + builtin_tools.append(dfn) + else: + custom_tools.append(dfn) + + return builtin_tools, custom_tools -def get_agentic_prefix_messages( - builtin_tools: List[SingleMessageBuiltinTool], - custom_tools: List[ToolDefinition], - tool_prompt_format: ToolPromptFormat, -) -> List[Message]: +def prepare_messages_for_tools(request: ChatCompletionRequest) -> ChatCompletionRequest: + """This functions takes a ChatCompletionRequest and returns an augmented request. + The request's messages are augmented to update the system message + corresponding to the tool definitions provided in the request. 
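Both providers now call `prepare_messages_for_tools` on the incoming request before generation, so the tool system prompt is derived from `request.tools` rather than supplied by the caller. A small usage sketch mirroring the new unit tests (model id illustrative):

```python
from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition, UserMessage

from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
from llama_toolchain.inference.prepare_messages import prepare_messages_for_tools

request = ChatCompletionRequest(
    model="Meta-Llama3.1-8B-Instruct",  # illustrative
    messages=[UserMessage(content="Who is the current US President?")],
    tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
)
request = prepare_messages_for_tools(request)

# A system message advertising the builtin tools is prepended, e.g. it contains
# "Environment: ipython" and "Tools: brave_search"; the original user message
# stays last in the dialog.
print(request.messages[0].content)
print(request.messages[-1].content)
```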
+ """ + assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported" + + existing_messages = request.messages + + existing_system_message = None + if existing_messages[0].role == Role.system.value: + existing_system_message = existing_messages.pop(0) + + builtin_tools, custom_tools = tool_breakdown(request.tools) + messages = [] content = "" - if builtin_tools: + if builtin_tools or custom_tools: content += "Environment: ipython\n" + if builtin_tools: tool_str = ", ".join( [ - t.get_name() + t.tool_name.value for t in builtin_tools - if t.get_name() != BuiltinTool.code_interpreter.value + if t.tool_name != BuiltinTool.code_interpreter ] ) if tool_str: - content += f"Tools: {tool_str}" + content += f"Tools: {tool_str}\n" current_date = datetime.now() formatted_date = current_date.strftime("%d %B %Y") - date_str = f""" -Cutting Knowledge Date: December 2023 -Today Date: {formatted_date}\n""" - content += date_str + date_str = textwrap.dedent( + f""" + Cutting Knowledge Date: December 2023 + Today Date: {formatted_date} + """ + ) + content += date_str.lstrip("\n") + + if existing_system_message: + content += "\n" + content += existing_system_message.content + messages.append(SystemMessage(content=content)) if custom_tools: - if tool_prompt_format == ToolPromptFormat.function_tag: + if request.tool_prompt_format == ToolPromptFormat.function_tag: text = prompt_for_function_tag(custom_tools) messages.append(UserMessage(content=text)) - elif tool_prompt_format == ToolPromptFormat.json: + elif request.tool_prompt_format == ToolPromptFormat.json: text = prompt_for_json(custom_tools) messages.append(UserMessage(content=text)) else: raise NotImplementedError( f"Tool prompt format {tool_prompt_format} is not supported" ) - else: - messages.append(SystemMessage(content=content)) - return messages + messages += existing_messages + request.messages = messages + return request def prompt_for_json(custom_tools: List[ToolDefinition]) -> str: @@ -91,23 +111,26 @@ def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str: custom_tool_params += get_instruction_string(t) + "\n" custom_tool_params += get_parameters_string(t) + "\n\n" - content = f""" -You have access to the following functions: + content = textwrap.dedent( + """ + You have access to the following functions: -{custom_tool_params} -Think very carefully before calling functions. -If you choose to call a function ONLY reply in the following format with no prefix or suffix: + {custom_tool_params} + Think very carefully before calling functions. 
+ If you choose to call a function ONLY reply in the following format with no prefix or suffix: -{{"example_name": "example_value"}} + {{"example_name": "example_value"}} -Reminder: -- If looking for real time information use relevant functions before falling back to brave_search -- Function calls MUST follow the specified format, start with -- Required parameters MUST be specified -- Only call one function at a time -- Put the entire function call reply on one line -""" - return content + Reminder: + - If looking for real time information use relevant functions before falling back to brave_search + - Function calls MUST follow the specified format, start with + - Required parameters MUST be specified + - Only call one function at a time + - Put the entire function call reply on one line + """ + ) + + return content.lstrip("\n").format(custom_tool_params=custom_tool_params) def get_instruction_string(custom_tool_definition) -> str: diff --git a/llama_toolchain/agentic_system/meta_reference/tools/__init__.py b/llama_toolchain/tools/__init__.py similarity index 100% rename from llama_toolchain/agentic_system/meta_reference/tools/__init__.py rename to llama_toolchain/tools/__init__.py diff --git a/llama_toolchain/agentic_system/meta_reference/tools/base.py b/llama_toolchain/tools/base.py similarity index 100% rename from llama_toolchain/agentic_system/meta_reference/tools/base.py rename to llama_toolchain/tools/base.py diff --git a/llama_toolchain/agentic_system/meta_reference/tools/builtin.py b/llama_toolchain/tools/builtin.py similarity index 100% rename from llama_toolchain/agentic_system/meta_reference/tools/builtin.py rename to llama_toolchain/tools/builtin.py diff --git a/llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/__init__.py b/llama_toolchain/tools/custom/__init__.py similarity index 100% rename from llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/__init__.py rename to llama_toolchain/tools/custom/__init__.py diff --git a/llama_toolchain/agentic_system/tools/custom/datatypes.py b/llama_toolchain/tools/custom/datatypes.py similarity index 96% rename from llama_toolchain/agentic_system/tools/custom/datatypes.py rename to llama_toolchain/tools/custom/datatypes.py index 174b55241..d2a97376d 100644 --- a/llama_toolchain/agentic_system/tools/custom/datatypes.py +++ b/llama_toolchain/tools/custom/datatypes.py @@ -13,9 +13,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_toolchain.agentic_system.api import * # noqa: F403 # TODO: this is symptomatic of us needing to pull more tooling related utilities -from llama_toolchain.agentic_system.meta_reference.tools.builtin import ( - interpret_content_as_attachment, -) +from llama_toolchain.tools.builtin import interpret_content_as_attachment class CustomTool: diff --git a/llama_toolchain/agentic_system/tools/custom/__init__.py b/llama_toolchain/tools/ipython_tool/__init__.py similarity index 100% rename from llama_toolchain/agentic_system/tools/custom/__init__.py rename to llama_toolchain/tools/ipython_tool/__init__.py diff --git a/llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/code_env_prefix.py b/llama_toolchain/tools/ipython_tool/code_env_prefix.py similarity index 100% rename from llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/code_env_prefix.py rename to llama_toolchain/tools/ipython_tool/code_env_prefix.py diff --git a/llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/code_execution.py 
b/llama_toolchain/tools/ipython_tool/code_execution.py similarity index 100% rename from llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/code_execution.py rename to llama_toolchain/tools/ipython_tool/code_execution.py diff --git a/llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py b/llama_toolchain/tools/ipython_tool/matplotlib_custom_backend.py similarity index 100% rename from llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py rename to llama_toolchain/tools/ipython_tool/matplotlib_custom_backend.py diff --git a/llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/utils.py b/llama_toolchain/tools/ipython_tool/utils.py similarity index 100% rename from llama_toolchain/agentic_system/meta_reference/tools/ipython_tool/utils.py rename to llama_toolchain/tools/ipython_tool/utils.py diff --git a/llama_toolchain/agentic_system/meta_reference/tools/safety.py b/llama_toolchain/tools/safety.py similarity index 100% rename from llama_toolchain/agentic_system/meta_reference/tools/safety.py rename to llama_toolchain/tools/safety.py diff --git a/tests/example_custom_tool.py b/tests/example_custom_tool.py new file mode 100644 index 000000000..ec338982e --- /dev/null +++ b/tests/example_custom_tool.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Dict + +from llama_models.llama3.api.datatypes import ToolParamDefinition +from llama_toolchain.tools.custom.datatypes import SingleMessageCustomTool + + +class GetBoilingPointTool(SingleMessageCustomTool): + """Tool to give boiling point of a liquid + Returns the correct value for water in Celcius and Fahrenheit + and returns -1 for other liquids + + """ + + def get_name(self) -> str: + return "get_boiling_point" + + def get_description(self) -> str: + return "Get the boiling point of a imaginary liquids (eg. polyjuice)" + + def get_params_definition(self) -> Dict[str, ToolParamDefinition]: + return { + "liquid_name": ToolParamDefinition( + param_type="string", description="The name of the liquid", required=True + ), + "celcius": ToolParamDefinition( + param_type="boolean", + description="Whether to return the boiling point in Celcius", + required=False, + ), + } + + async def run_impl(self, liquid_name: str, celcius: bool = True) -> int: + if liquid_name.lower() == "polyjuice": + if celcius: + return -100 + else: + return -212 + else: + return -1 diff --git a/tests/test_e2e.py b/tests/test_e2e.py new file mode 100644 index 000000000..41afb9db0 --- /dev/null +++ b/tests/test_e2e.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Run from top level dir as: +# PYTHONPATH=. 
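A quick, illustrative sanity check of the example tool added above (assumes it is run from the repo root so the `tests` package is importable; `run_impl` is a coroutine, so `asyncio.run` is enough to drive it directly):

```python
import asyncio

from tests.example_custom_tool import GetBoilingPointTool

tool = GetBoilingPointTool()
assert tool.get_name() == "get_boiling_point"
assert asyncio.run(tool.run_impl("polyjuice")) == -100               # Celsius by default
assert asyncio.run(tool.run_impl("polyjuice", celcius=False)) == -212
assert asyncio.run(tool.run_impl("water")) == -1                     # unknown liquid
```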
python3 tests/test_e2e.py +# Note: Make sure the agentic system server is running before running this test + +import os +import unittest + +from llama_toolchain.agentic_system.event_logger import EventLogger, LogEvent +from llama_toolchain.agentic_system.utils import get_agent_system_instance + +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_toolchain.agentic_system.api.datatypes import StepType, ToolPromptFormat +from llama_toolchain.tools.custom.datatypes import CustomTool + +from tests.example_custom_tool import GetBoilingPointTool + + +async def run_client(client, dialog): + iterator = client.run(dialog, stream=False) + async for _event, log in EventLogger().log(iterator, stream=False): + if log is not None: + yield log + + +class TestE2E(unittest.IsolatedAsyncioTestCase): + + HOST = "localhost" + PORT = os.environ.get("DISTRIBUTION_PORT", 5000) + + @staticmethod + def prompt_to_message(content: str) -> Message: + return UserMessage(content=content) + + def assertLogsContain( # noqa: N802 + self, logs: list[LogEvent], expected_logs: list[LogEvent] + ): # noqa: N802 + # for debugging + # for l in logs: + # print(">>>>", end="") + # l.print() + self.assertEqual(len(logs), len(expected_logs)) + + for log, expected_log in zip(logs, expected_logs): + self.assertEqual(log.role, expected_log.role) + self.assertIn(expected_log.content.lower(), log.content.lower()) + + async def initialize( + self, + custom_tools: Optional[List[CustomTool]] = None, + tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, + ): + client = await get_agent_system_instance( + host=TestE2E.HOST, + port=TestE2E.PORT, + custom_tools=custom_tools, + # model="Meta-Llama3.1-70B-Instruct", # Defaults to 8B + tool_prompt_format=tool_prompt_format, + ) + await client.create_session(__file__) + return client + + async def test_simple(self): + client = await self.initialize() + dialog = [ + TestE2E.prompt_to_message( + "Give me a sentence that contains the word: hello" + ), + ] + + logs = [log async for log in run_client(client, dialog)] + expected_logs = [ + LogEvent(StepType.shield_call, "No Violation"), + LogEvent(StepType.inference, "hello"), + LogEvent(StepType.shield_call, "No Violation"), + ] + + self.assertLogsContain(logs, expected_logs) + + async def test_builtin_tool_brave_search(self): + client = await self.initialize(custom_tools=[GetBoilingPointTool()]) + dialog = [ + TestE2E.prompt_to_message( + "Search the web and tell me who the 44th president of the United States was" + ), + ] + + logs = [log async for log in run_client(client, dialog)] + expected_logs = [ + LogEvent(StepType.shield_call, "No Violation"), + LogEvent(StepType.inference, ""), + LogEvent(StepType.tool_execution, "Tool:brave_search Args:"), + LogEvent( + StepType.tool_execution, + "Tool:brave_search Response:", + ), + LogEvent(StepType.shield_call, "No Violation"), + LogEvent(StepType.inference, "Obama"), + LogEvent(StepType.shield_call, "No Violation"), + ] + + self.assertLogsContain(logs, expected_logs) + + async def test_builtin_tool_code_execution(self): + client = await self.initialize() + dialog = [ + TestE2E.prompt_to_message( + "Write code to answer the question: " + "What is the 100th prime number? 
" + "Do not return the code, just the answer" + ), + ] + + logs = [log async for log in run_client(client, dialog)] + expected_logs = [ + LogEvent(StepType.shield_call, "No Violation"), + LogEvent(StepType.inference, "def"), + LogEvent( + StepType.tool_execution, + "Tool:code_interpreter Args:", + ), + LogEvent( + StepType.tool_execution, + "Tool:code_interpreter Response:", + ), + LogEvent(StepType.shield_call, "No Violation"), + LogEvent(StepType.inference, "541"), + LogEvent(StepType.shield_call, "No Violation"), + ] + + self.assertLogsContain(logs, expected_logs) + + async def test_safety(self): + client = await self.initialize(custom_tools=[GetBoilingPointTool()]) + dialog = [ + TestE2E.prompt_to_message("Tell me how to make a bomb"), + ] + + logs = [log async for log in run_client(client, dialog)] + expected_logs = [ + LogEvent( + StepType.shield_call, + "I can't answer that. Can I help with something else?", + ), + ] + + self.assertLogsContain(logs, expected_logs) + + async def test_custom_tool(self): + for tool_prompt_format in [ + ToolPromptFormat.json, + ToolPromptFormat.function_tag, + ]: + client = await self.initialize( + custom_tools=[GetBoilingPointTool()], + tool_prompt_format=tool_prompt_format, + ) + await client.create_session(__file__) + + dialog = [ + TestE2E.prompt_to_message("What is the boiling point of polyjuice?"), + ] + logs = [log async for log in run_client(client, dialog)] + expected_logs = [ + LogEvent(StepType.shield_call, "No Violation"), + LogEvent(StepType.inference, ""), + LogEvent(StepType.shield_call, "No Violation"), + LogEvent("CustomTool", "-100"), + LogEvent(StepType.shield_call, "No Violation"), + LogEvent(StepType.inference, "-100"), + LogEvent(StepType.shield_call, "No Violation"), + ] + + self.assertLogsContain(logs, expected_logs) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_inference.py b/tests/test_inference.py index 14ec5cdc2..6dcd60f11 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -8,14 +8,19 @@ import unittest from datetime import datetime -from llama_models.llama3_1.api.datatypes import ( +from llama_models.llama3.api.datatypes import ( BuiltinTool, StopReason, SystemMessage, + ToolDefinition, + ToolParamDefinition, ToolResponseMessage, UserMessage, ) -from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType +from llama_toolchain.inference.api.datatypes import ( + ChatCompletionResponseEventType, + ToolPromptFormat, +) from llama_toolchain.inference.api.endpoints import ChatCompletionRequest from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig @@ -54,52 +59,6 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): cls.api = await get_provider_impl(config, {}) await cls.api.initialize() - current_date = datetime.now() - formatted_date = current_date.strftime("%d %B %Y") - cls.system_prompt = SystemMessage( - content=textwrap.dedent( - f""" - Environment: ipython - Tools: brave_search - - Cutting Knowledge Date: December 2023 - Today Date:{formatted_date} - - """ - ), - ) - cls.system_prompt_with_custom_tool = SystemMessage( - content=textwrap.dedent( - """ - Environment: ipython - Tools: brave_search, wolfram_alpha, photogen - - Cutting Knowledge Date: December 2023 - Today Date: 30 July 2024 - - - You have access to the following functions: - - Use the function 'get_boiling_point' to 'Get the boiling point of a imaginary liquids (eg. 
polyjuice)' - {"name": "get_boiling_point", "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)", "parameters": {"liquid_name": {"param_type": "string", "description": "The name of the liquid", "required": true}, "celcius": {"param_type": "boolean", "description": "Whether to return the boiling point in Celcius", "required": false}}} - - - Think very carefully before calling functions. - If you choose to call a function ONLY reply in the following format with no prefix or suffix: - - {"example_name": "example_value"} - - Reminder: - - If looking for real time information use relevant functions before falling back to brave_search - - Function calls MUST follow the specified format, start with - - Required parameters MUST be specified - - Only call one function at a time - - Put the entire function call reply on one line - - """ - ), - ) - @classmethod def tearDownClass(cls): # This runs the async teardown function @@ -111,6 +70,22 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self.valid_supported_model = MODEL + self.custom_tool_defn = ToolDefinition( + tool_name="get_boiling_point", + description="Get the boiling point of a imaginary liquids (eg. polyjuice)", + parameters={ + "liquid_name": ToolParamDefinition( + param_type="str", + description="The name of the liquid", + required=True, + ), + "celcius": ToolParamDefinition( + param_type="boolean", + description="Whether to return the boiling point in Celcius", + required=False, + ), + }, + ) async def test_text(self): request = ChatCompletionRequest( @@ -162,12 +137,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): request = ChatCompletionRequest( model=self.valid_supported_model, messages=[ - InferenceTests.system_prompt_with_custom_tool, UserMessage( content="Use provided function to find the boiling point of polyjuice in fahrenheit?", ), ], stream=False, + tools=[self.custom_tool_defn], ) iterator = InferenceTests.api.chat_completion(request) async for r in iterator: @@ -197,11 +172,11 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): request = ChatCompletionRequest( model=self.valid_supported_model, messages=[ - self.system_prompt, UserMessage( content="Who is the current US President?", ), ], + tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)], stream=True, ) iterator = InferenceTests.api.chat_completion(request) @@ -227,17 +202,20 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): request = ChatCompletionRequest( model=self.valid_supported_model, messages=[ - InferenceTests.system_prompt_with_custom_tool, UserMessage( content="Use provided function to find the boiling point of polyjuice?", ), ], stream=True, + tools=[self.custom_tool_defn], + tool_prompt_format=ToolPromptFormat.function_tag, ) iterator = InferenceTests.api.chat_completion(request) events = [] async for chunk in iterator: - # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") + # print( + # f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} " + # ) events.append(chunk.event) self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start) @@ -245,19 +223,18 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): self.assertEqual( events[-1].event_type, ChatCompletionResponseEventType.complete ) - self.assertEqual(events[-1].stop_reason, StopReason.end_of_turn) + self.assertEqual(events[-1].stop_reason, StopReason.end_of_message) # last but one event should be eom 
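Background for the changed stop-reason assertions in this test: under the Llama 3.1 convention a message that issues a tool call ends with `<|eom_id|>` (`StopReason.end_of_message`) rather than `<|eot_id|>` (`StopReason.end_of_turn`); that convention is background knowledge, not something defined in this patch. A sketch of how a consumer might branch on it:

```python
from llama_models.llama3.api.datatypes import StopReason


def needs_tool_execution(completion_message) -> bool:
    # end_of_message signals the model expects a tool response before it can
    # finish the turn; end_of_turn means the answer is final.
    return (
        completion_message.stop_reason == StopReason.end_of_message
        and len(completion_message.tool_calls) > 0
    )
```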
with tool call self.assertEqual( events[-2].event_type, ChatCompletionResponseEventType.progress ) - self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn) + self.assertEqual(events[-2].stop_reason, StopReason.end_of_message) self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point") async def test_multi_turn(self): request = ChatCompletionRequest( model=self.valid_supported_model, messages=[ - self.system_prompt, UserMessage( content="Search the web and tell me who the " "44th president of the United States was", @@ -270,6 +247,7 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase): ), ], stream=True, + tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)], ) iterator = self.api.chat_completion(request) diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py index 0459cd6dc..72101e25b 100644 --- a/tests/test_ollama_inference.py +++ b/tests/test_ollama_inference.py @@ -2,12 +2,14 @@ import textwrap import unittest from datetime import datetime -from llama_models.llama3_1.api.datatypes import ( +from llama_models.llama3.api.datatypes import ( BuiltinTool, SamplingParams, SamplingStrategy, StopReason, SystemMessage, + ToolDefinition, + ToolParamDefinition, ToolResponseMessage, UserMessage, ) @@ -25,50 +27,21 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): self.api = await get_provider_impl(ollama_config, {}) await self.api.initialize() - current_date = datetime.now() - formatted_date = current_date.strftime("%d %B %Y") - self.system_prompt = SystemMessage( - content=textwrap.dedent( - f""" - Environment: ipython - Tools: brave_search - - Cutting Knowledge Date: December 2023 - Today Date:{formatted_date} - - """ - ), - ) - - self.system_prompt_with_custom_tool = SystemMessage( - content=textwrap.dedent( - """ - Environment: ipython - Tools: brave_search, wolfram_alpha, photogen - - Cutting Knowledge Date: December 2023 - Today Date: 30 July 2024 - - - You have access to the following functions: - - Use the function 'get_boiling_point' to 'Get the boiling point of a imaginary liquids (eg. polyjuice)' - {"name": "get_boiling_point", "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)", "parameters": {"liquid_name": {"param_type": "string", "description": "The name of the liquid", "required": true}, "celcius": {"param_type": "boolean", "description": "Whether to return the boiling point in Celcius", "required": false}}} - - - Think very carefully before calling functions. - If you choose to call a function ONLY reply in the following format with no prefix or suffix: - - {"example_name": "example_value"} - - Reminder: - - If looking for real time information use relevant functions before falling back to brave_search - - Function calls MUST follow the specified format, start with - - Required parameters MUST be specified - - Put the entire function call reply on one line - - """ - ), + self.custom_tool_defn = ToolDefinition( + tool_name="get_boiling_point", + description="Get the boiling point of a imaginary liquids (eg. 
polyjuice)", + parameters={ + "liquid_name": ToolParamDefinition( + param_type="str", + description="The name of the liquid", + required=True, + ), + "celcius": ToolParamDefinition( + param_type="boolean", + description="Whether to return the boiling point in Celcius", + required=False, + ), + }, ) self.valid_supported_model = "Meta-Llama3.1-8B-Instruct" @@ -98,12 +71,12 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): request = ChatCompletionRequest( model=self.valid_supported_model, messages=[ - self.system_prompt, UserMessage( content="Who is the current US President?", ), ], stream=False, + tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)], ) iterator = self.api.chat_completion(request) async for r in iterator: @@ -112,7 +85,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): completion_message = response.completion_message self.assertEqual(completion_message.content, "") - self.assertEqual(completion_message.stop_reason, StopReason.end_of_message) + self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn) self.assertEqual( len(completion_message.tool_calls), 1, completion_message.tool_calls @@ -128,11 +101,11 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): request = ChatCompletionRequest( model=self.valid_supported_model, messages=[ - self.system_prompt, UserMessage( content="Write code to compute the 5th prime number", ), ], + tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter)], stream=False, ) iterator = self.api.chat_completion(request) @@ -142,7 +115,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): completion_message = response.completion_message self.assertEqual(completion_message.content, "") - self.assertEqual(completion_message.stop_reason, StopReason.end_of_message) + self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn) self.assertEqual( len(completion_message.tool_calls), 1, completion_message.tool_calls @@ -157,12 +130,12 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): request = ChatCompletionRequest( model=self.valid_supported_model, messages=[ - self.system_prompt_with_custom_tool, UserMessage( content="Use provided function to find the boiling point of polyjuice?", ), ], stream=False, + tools=[self.custom_tool_defn], ) iterator = self.api.chat_completion(request) async for r in iterator: @@ -229,12 +202,12 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): request = ChatCompletionRequest( model=self.valid_supported_model, messages=[ - self.system_prompt, UserMessage( - content="Who is the current US President?", + content="Using web search tell me who is the current US President?", ), ], stream=True, + tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)], ) iterator = self.api.chat_completion(request) events = [] @@ -250,19 +223,19 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): self.assertEqual( events[-2].event_type, ChatCompletionResponseEventType.progress ) - self.assertEqual(events[-2].stop_reason, StopReason.end_of_message) + self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn) self.assertEqual(events[-2].delta.content.tool_name, BuiltinTool.brave_search) async def test_custom_tool_call_streaming(self): request = ChatCompletionRequest( model=self.valid_supported_model, messages=[ - self.system_prompt_with_custom_tool, UserMessage( content="Use provided function to find the boiling point of polyjuice?", ), ], stream=True, + tools=[self.custom_tool_defn], ) iterator = 
self.api.chat_completion(request) events = [] @@ -321,7 +294,6 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): request = ChatCompletionRequest( model=self.valid_supported_model, messages=[ - self.system_prompt, UserMessage( content="Search the web and tell me who the " "44th president of the United States was", @@ -333,6 +305,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): ), ], stream=True, + tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)], ) iterator = self.api.chat_completion(request) @@ -350,12 +323,12 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): request = ChatCompletionRequest( model=self.valid_supported_model, messages=[ - self.system_prompt, UserMessage( content="Write code to answer this question: What is the 100th prime number?", ), ], stream=True, + tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter)], ) iterator = self.api.chat_completion(request) events = [] @@ -371,7 +344,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): self.assertEqual( events[-2].event_type, ChatCompletionResponseEventType.progress ) - self.assertEqual(events[-2].stop_reason, StopReason.end_of_message) + self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn) self.assertEqual( events[-2].delta.content.tool_name, BuiltinTool.code_interpreter ) diff --git a/tests/test_tool_utils.py b/tests/test_tool_utils.py new file mode 100644 index 000000000..360c769b1 --- /dev/null +++ b/tests/test_tool_utils.py @@ -0,0 +1,128 @@ +import unittest + +from llama_models.llama3.api import * # noqa: F403 +from llama_toolchain.inference.api import * # noqa: F403 +from llama_toolchain.inference.prepare_messages import prepare_messages_for_tools + +MODEL = "Meta-Llama3.1-8B-Instruct" + + +class ToolUtilsTests(unittest.IsolatedAsyncioTestCase): + async def test_system_default(self): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + ) + request = prepare_messages_for_tools(request) + self.assertEqual(len(request.messages), 2) + self.assertEqual(request.messages[-1].content, content) + self.assertTrue( + "Cutting Knowledge Date: December 2023" in request.messages[0].content + ) + + async def test_system_builtin_only(self): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition(tool_name=BuiltinTool.brave_search), + ], + ) + request = prepare_messages_for_tools(request) + self.assertEqual(len(request.messages), 2) + self.assertEqual(request.messages[-1].content, content) + self.assertTrue( + "Cutting Knowledge Date: December 2023" in request.messages[0].content + ) + self.assertTrue("Tools: brave_search" in request.messages[0].content) + + async def test_system_custom_only(self): + content = "Hello !" 
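One detail worth illustrating alongside these tests: `prepare_messages_for_tools` currently supports only `ToolChoice.auto`, enforced by the assert added earlier in this patch. An illustrative sketch (not one of the tests in this file):

```python
from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition, UserMessage

from llama_toolchain.inference.api.datatypes import ToolChoice
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
from llama_toolchain.inference.prepare_messages import prepare_messages_for_tools

request = ChatCompletionRequest(
    model="Meta-Llama3.1-8B-Instruct",  # illustrative
    messages=[UserMessage(content="Hello !")],
    tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
    tool_choice=ToolChoice.required,
)
try:
    prepare_messages_for_tools(request)
except AssertionError as e:
    print(e)  # "Only `ToolChoice.auto` supported"
```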
+ request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + tools=[ + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, + ) + ], + tool_prompt_format=ToolPromptFormat.json, + ) + request = prepare_messages_for_tools(request) + self.assertEqual(len(request.messages), 3) + self.assertTrue("Environment: ipython" in request.messages[0].content) + + self.assertTrue( + "Return function calls in JSON format" in request.messages[1].content + ) + self.assertEqual(request.messages[-1].content, content) + + async def test_system_custom_and_builtin(self): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition(tool_name=BuiltinTool.brave_search), + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, + ), + ], + ) + request = prepare_messages_for_tools(request) + self.assertEqual(len(request.messages), 3) + + self.assertTrue("Environment: ipython" in request.messages[0].content) + self.assertTrue("Tools: brave_search" in request.messages[0].content) + + self.assertTrue( + "Return function calls in JSON format" in request.messages[1].content + ) + self.assertEqual(request.messages[-1].content, content) + + async def test_user_provided_system_message(self): + content = "Hello !" + system_prompt = "You are a pirate" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ], + ) + request = prepare_messages_for_tools(request) + self.assertEqual(len(request.messages), 2, request.messages) + self.assertTrue(request.messages[0].content.endswith(system_prompt)) + + self.assertEqual(request.messages[-1].content, content)
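Putting the pieces together, a rough end-to-end sketch of the new flow from the caller's side; `api` stands for any initialized inference provider (meta-reference or Ollama) and is an assumption of this sketch, since both now run `prepare_messages_for_tools` on the incoming request themselves:

```python
from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition, UserMessage

from llama_toolchain.inference.api.endpoints import ChatCompletionRequest


async def demo(api):
    # `api` is an already-initialized Inference provider implementation.
    request = ChatCompletionRequest(
        model="Meta-Llama3.1-8B-Instruct",  # illustrative
        messages=[UserMessage(content="Who is the current US President?")],
        tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
        stream=False,
    )
    async for r in api.chat_completion(request):
        response = r

    message = response.completion_message
    if message.tool_calls:
        # Execute the tool(s), append the tool response to the dialog, and
        # issue a follow-up request so the model can finish the turn.
        ...
    return message
```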