From ab8193c88c2cae92d084e83c61a7ec4d9b4c0def Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Fri, 23 Aug 2024 14:21:12 -0700 Subject: [PATCH] use templates for generating system prompts --- .../common/prompt_templates/base.py | 26 +++ .../common/prompt_templates/system_prompts.py | 206 +++++++++++++++++ .../inference/meta_reference/inference.py | 6 +- llama_toolchain/inference/ollama/ollama.py | 8 +- llama_toolchain/inference/prepare_messages.py | 215 ++++-------------- tests/test_ollama_inference.py | 6 +- ...tool_utils.py => test_prepare_messages.py} | 60 +++-- tests/test_prompt_templates.py | 101 ++++++++ 8 files changed, 410 insertions(+), 218 deletions(-) create mode 100644 llama_toolchain/common/prompt_templates/base.py create mode 100644 llama_toolchain/common/prompt_templates/system_prompts.py rename tests/{test_tool_utils.py => test_prepare_messages.py} (61%) create mode 100644 tests/test_prompt_templates.py diff --git a/llama_toolchain/common/prompt_templates/base.py b/llama_toolchain/common/prompt_templates/base.py new file mode 100644 index 000000000..de229bcb2 --- /dev/null +++ b/llama_toolchain/common/prompt_templates/base.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from typing import Any, Dict, List + +from jinja2 import Template + + +@dataclass +class PromptTemplate: + template: str + data: Dict[str, Any] + + def render(self): + template = Template(self.template) + return template.render(self.data) + + +class PromptTemplateGeneratorBase: + """ + Base class for prompt template generators. + """ + + def gen(self, *args, **kwargs) -> PromptTemplate: + raise NotImplementedError() + + def data_examples(self) -> List[Any]: + raise NotImplementedError() diff --git a/llama_toolchain/common/prompt_templates/system_prompts.py b/llama_toolchain/common/prompt_templates/system_prompts.py new file mode 100644 index 000000000..aeabb116e --- /dev/null +++ b/llama_toolchain/common/prompt_templates/system_prompts.py @@ -0,0 +1,206 @@ +import textwrap +from datetime import datetime +from typing import Any, Dict, List + +from llama_models.llama3.api.datatypes import ( + BuiltinTool, + ToolDefinition, + ToolParamDefinition, +) + +from .base import PromptTemplate, PromptTemplateGeneratorBase + + +class SystemDefaultGenerator(PromptTemplateGeneratorBase): + + def gen(self, *args, **kwargs) -> PromptTemplate: + template_str = textwrap.dedent( + """ + Cutting Knowledge Date: December 2023 + Today Date: {{ today }} + """ + ) + return PromptTemplate( + template_str.lstrip("\n"), + {"today": datetime.now().strftime("%d %B %Y")}, + ) + + def data_examples(self) -> List[Any]: + return [None] + + +class BuiltinToolGenerator(PromptTemplateGeneratorBase): + + def _tool_breakdown(self, tools: List[ToolDefinition]): + builtin_tools, custom_tools = [], [] + for dfn in tools: + if isinstance(dfn.tool_name, BuiltinTool): + builtin_tools.append(dfn) + else: + custom_tools.append(dfn) + + return builtin_tools, custom_tools + + def gen(self, tools: List[ToolDefinition]) -> PromptTemplate: + builtin_tools, custom_tools = self._tool_breakdown(tools) + data = [] + template_str = textwrap.dedent( + """ + {% if builtin_tools or custom_tools -%} + Environment: ipython + {% endif -%} + {% set builtin_tools = builtin_tools | reject('equalto', 'code_interpreter') | list -%} + {% if builtin_tools -%} + Tools: {{ builtin_tools | join(", ") | trim -}} + {% endif %} + """ + ) + return PromptTemplate( + template_str.lstrip("\n"), + { + "builtin_tools": [t.tool_name.value for t in builtin_tools], + "custom_tools": custom_tools, + }, + ) + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + # builtin tools + [ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition(tool_name=BuiltinTool.brave_search), + ToolDefinition(tool_name=BuiltinTool.wolfram_alpha), + ], + # only code interpretor + [ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ], + ] + + +class JsonCustomToolGenerator(PromptTemplateGeneratorBase): + + def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + template_str = textwrap.dedent( + """ + Answer the user's question by making use of the following functions if needed. + If none of the function can be used, please say so. + Here is a list of functions in JSON format: + {% for t in custom_tools -%} + {# manually setting up JSON because jinja sorts keys in unexpected ways -#} + {%- set tname = t.tool_name -%} + {%- set tdesc = t.description -%} + {%- set tparams = t.parameters -%} + {%- set required_params = [] -%} + {%- for name, param in tparams.items() if param.required == true -%} + {%- set _ = required_params.append(name) -%} + {%- endfor -%} + { + "type": "function", + "function": { + "name": "{{tname}}", + "description": "{{tdesc}}", + "parameters": { + "type": "object", + "properties": [ + {%- for name, param in tparams.items() %} + { + "{{name}}": { + "type": "object", + "description": "{{param.description}}" + } + }{% if not loop.last %},{% endif %} + {%- endfor %} + ], + "required": {{ required_params | tojson }} + } + } + } + {% endfor %} + Return function calls in JSON format. + """ + ) + + return PromptTemplate( + template_str.lstrip("\n"), + {"custom_tools": [t.model_dump() for t in custom_tools]}, + ) + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + [ + ToolDefinition( + tool_name="trending_songs", + description="Returns the trending songs on a Music site", + parameters={ + "n": ToolParamDefinition( + param_type="int", + description="The number of songs to return", + required=True, + ), + "genre": ToolParamDefinition( + param_type="str", + description="The genre of the songs to return", + required=False, + ), + }, + ), + ] + ] + + +class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase): + + def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + template_str = textwrap.dedent( + """ + You have access to the following functions: + + {% for t in custom_tools %} + {#- manually setting up JSON because jinja sorts keys in unexpected ways -#} + {%- set tname = t.tool_name -%} + {%- set tdesc = t.description -%} + {%- set tparams = t.parameters | tojson -%} + Use the function '{{ tname }}' to '{{ tdesc }}': + {"name": "{{tname}}", "description": "{{tdesc}}", "parameters": {{tparams}}} + + {% endfor -%} + Think very carefully before calling functions. + If a you choose to call a function ONLY reply in the following format with no prefix or suffix: + + {"example_name": "example_value"} + + Reminder: + - If looking for real time information use relevant functions before falling back to brave_search + - Function calls MUST follow the specified format, start with + - Required parameters MUST be specified + - Only call one function at a time + - Put the entire function call reply on one line + """ + ) + return PromptTemplate( + template_str.lstrip("\n"), + {"custom_tools": [t.model_dump() for t in custom_tools]}, + ) + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + [ + ToolDefinition( + tool_name="trending_songs", + description="Returns the trending songs on a Music site", + parameters={ + "n": ToolParamDefinition( + param_type="int", + description="The number of songs to return", + required=True, + ), + "genre": ToolParamDefinition( + param_type="str", + description="The genre of the songs to return", + required=False, + ), + }, + ), + ] + ] diff --git a/llama_toolchain/inference/meta_reference/inference.py b/llama_toolchain/inference/meta_reference/inference.py index dc674a25b..87ffc5226 100644 --- a/llama_toolchain/inference/meta_reference/inference.py +++ b/llama_toolchain/inference/meta_reference/inference.py @@ -22,7 +22,7 @@ from llama_toolchain.inference.api import ( ToolCallDelta, ToolCallParseStatus, ) -from llama_toolchain.inference.prepare_messages import prepare_messages_for_tools +from llama_toolchain.inference.prepare_messages import prepare_messages from .config import MetaReferenceImplConfig from .model_parallel import LlamaModelParallelGenerator @@ -67,7 +67,7 @@ class MetaReferenceInferenceImpl(Inference): ) -> AsyncIterator[ Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse] ]: - request = prepare_messages_for_tools(request) + messages = prepare_messages(request) model = resolve_model(request.model) if model is None: raise RuntimeError( @@ -99,7 +99,7 @@ class MetaReferenceInferenceImpl(Inference): ipython = False for token_result in self.generator.chat_completion( - messages=request.messages, + messages=messages, temperature=request.sampling_params.temperature, top_p=request.sampling_params.top_p, max_gen_len=request.sampling_params.max_tokens, diff --git a/llama_toolchain/inference/ollama/ollama.py b/llama_toolchain/inference/ollama/ollama.py index 8bfd38a71..235cb20cc 100644 --- a/llama_toolchain/inference/ollama/ollama.py +++ b/llama_toolchain/inference/ollama/ollama.py @@ -32,7 +32,7 @@ from llama_toolchain.inference.api import ( ToolCallDelta, ToolCallParseStatus, ) -from llama_toolchain.inference.prepare_messages import prepare_messages_for_tools +from llama_toolchain.inference.prepare_messages import prepare_messages from .config import OllamaImplConfig # TODO: Eventually this will move to the llama cli model list command @@ -111,7 +111,7 @@ class OllamaInference(Inference): return options async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: - request = prepare_messages_for_tools(request) + messages = prepare_messages(request) # accumulate sampling params and other options to pass to ollama options = self.get_ollama_chat_options(request) ollama_model = self.resolve_ollama_model(request.model) @@ -133,7 +133,7 @@ class OllamaInference(Inference): if not request.stream: r = await self.client.chat( model=ollama_model, - messages=self._messages_to_ollama_messages(request.messages), + messages=self._messages_to_ollama_messages(messages), stream=False, options=options, ) @@ -161,7 +161,7 @@ class OllamaInference(Inference): ) stream = await self.client.chat( model=ollama_model, - messages=self._messages_to_ollama_messages(request.messages), + messages=self._messages_to_ollama_messages(messages), stream=True, options=options, ) diff --git a/llama_toolchain/inference/prepare_messages.py b/llama_toolchain/inference/prepare_messages.py index e23bbbe8f..83aff57f9 100644 --- a/llama_toolchain/inference/prepare_messages.py +++ b/llama_toolchain/inference/prepare_messages.py @@ -1,203 +1,66 @@ -import json -import os import textwrap -from datetime import datetime from llama_toolchain.inference.api import * # noqa: F403 -from llama_toolchain.tools.builtin import ( - BraveSearchTool, - CodeInterpreterTool, - PhotogenTool, - WolframAlphaTool, +from llama_toolchain.common.prompt_templates.system_prompts import ( + BuiltinToolGenerator, + FunctionTagCustomToolGenerator, + JsonCustomToolGenerator, + SystemDefaultGenerator, ) -def tool_breakdown(tools: List[ToolDefinition]) -> str: - builtin_tools, custom_tools = [], [] - for dfn in tools: - if isinstance(dfn.tool_name, BuiltinTool): - builtin_tools.append(dfn) - else: - custom_tools.append(dfn) +def prepare_messages(request: ChatCompletionRequest) -> List[Message]: - return builtin_tools, custom_tools - - -def prepare_messages_for_tools(request: ChatCompletionRequest) -> ChatCompletionRequest: - """This functions takes a ChatCompletionRequest and returns an augmented request. - The request's messages are augmented to update the system message - corresponding to the tool definitions provided in the request. - """ assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported" existing_messages = request.messages - existing_system_message = None if existing_messages[0].role == Role.system.value: existing_system_message = existing_messages.pop(0) - builtin_tools, custom_tools = tool_breakdown(request.tools) + assert ( + existing_messages[0].role != Role.system.value + ), "Should only have 1 system message" messages = [] - content = "" - if builtin_tools or custom_tools: - content += "Environment: ipython\n" - if builtin_tools: - tool_str = ", ".join( - [ - t.tool_name.value - for t in builtin_tools - if t.tool_name != BuiltinTool.code_interpreter - ] - ) - if tool_str: - content += f"Tools: {tool_str}\n" + default_gen = SystemDefaultGenerator() + default_template = default_gen.gen() - current_date = datetime.now() - formatted_date = current_date.strftime("%d %B %Y") - date_str = textwrap.dedent( - f""" - Cutting Knowledge Date: December 2023 - Today Date: {formatted_date} - """ - ) - content += date_str.lstrip("\n") + sys_content = "" + + tool_template = None + if request.tools: + tool_gen = BuiltinToolGenerator() + tool_template = tool_gen.gen(request.tools) + + sys_content += tool_template.render() + sys_content += "\n" + + sys_content += default_template.render() if existing_system_message: - content += "\n" - content += existing_system_message.content + sys_content += "\n" + sys_content += existing_system_message.content - messages.append(SystemMessage(content=content)) + messages.append(SystemMessage(content=sys_content)) - if custom_tools: - if request.tool_prompt_format == ToolPromptFormat.function_tag: - text = prompt_for_function_tag(custom_tools) - messages.append(UserMessage(content=text)) - elif request.tool_prompt_format == ToolPromptFormat.json: - text = prompt_for_json(custom_tools) - messages.append(UserMessage(content=text)) + has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools) + if has_custom_tools: + if request.tool_prompt_format == ToolPromptFormat.json: + tool_gen = JsonCustomToolGenerator() + elif request.tool_prompt_format == ToolPromptFormat.function_tag: + tool_gen = FunctionTagCustomToolGenerator() else: - raise NotImplementedError( - f"Tool prompt format {tool_prompt_format} is not supported" + raise ValueError( + f"Non supported ToolPromptFormat {request.tool_prompt_format}" ) + custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)] + custom_template = tool_gen.gen(custom_tools) + messages.append(UserMessage(content=custom_template.render())) + + # Add back existing messages from the request messages += existing_messages - request.messages = messages - return request - -def prompt_for_json(custom_tools: List[ToolDefinition]) -> str: - tool_defs = "\n".join( - translate_custom_tool_definition_to_json(t) for t in custom_tools - ) - content = textwrap.dedent( - """ - Answer the user's question by making use of the following functions if needed. - If none of the function can be used, please say so. - Here is a list of functions in JSON format: - {tool_defs} - - Return function calls in JSON format. - """ - ) - content = content.lstrip("\n").format(tool_defs=tool_defs) - return content - - -def prompt_for_function_tag(custom_tools: List[ToolDefinition]) -> str: - custom_tool_params = "" - for t in custom_tools: - custom_tool_params += get_instruction_string(t) + "\n" - custom_tool_params += get_parameters_string(t) + "\n\n" - - content = textwrap.dedent( - """ - You have access to the following functions: - - {custom_tool_params} - Think very carefully before calling functions. - If you choose to call a function ONLY reply in the following format with no prefix or suffix: - - {{"example_name": "example_value"}} - - Reminder: - - If looking for real time information use relevant functions before falling back to brave_search - - Function calls MUST follow the specified format, start with - - Required parameters MUST be specified - - Only call one function at a time - - Put the entire function call reply on one line - """ - ) - - return content.lstrip("\n").format(custom_tool_params=custom_tool_params) - - -def get_instruction_string(custom_tool_definition) -> str: - return f"Use the function '{custom_tool_definition.tool_name}' to '{custom_tool_definition.description}'" - - -def get_parameters_string(custom_tool_definition) -> str: - return json.dumps( - { - "name": custom_tool_definition.tool_name, - "description": custom_tool_definition.description, - "parameters": { - name: definition.__dict__ - for name, definition in custom_tool_definition.parameters.items() - }, - } - ) - - -def translate_custom_tool_definition_to_json(tool_def): - """Translates ToolDefinition to json as expected by model - eg. output for a function - { - "type": "function", - "function": { - "name": "conv_int", - "description": "Convert serialized fract24 integer into int value.", - "parameters": { - "type": "object", - "properties": [ - { - "data": { - "type": "object", - "description": "" - } - } - ], - "required": ["data"] - } - } - } - """ - assert isinstance(tool_def.tool_name, str) - func_def = {"type": "function", "function": {}} - func_def["function"]["name"] = tool_def.tool_name - func_def["function"]["description"] = tool_def.description or "" - if tool_def.parameters: - required = [] - properties = [] - for p_name, p_def in tool_def.parameters.items(): - properties.append( - { - p_name: { - # TODO: see if this should not always be object - "type": "object", - "description": p_def.description or "", - } - } - ) - if p_def.required: - required.append(p_name) - func_def["function"]["parameters"] = { - "type": "object", - "properties": properties, - "required": required, - } - else: - func_def["function"]["parameters"] = {} - - return json.dumps(func_def, indent=4) + return messages diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py index 72101e25b..5ff1b94f9 100644 --- a/tests/test_ollama_inference.py +++ b/tests/test_ollama_inference.py @@ -13,7 +13,10 @@ from llama_models.llama3.api.datatypes import ( ToolResponseMessage, UserMessage, ) -from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType +from llama_toolchain.inference.api.datatypes import ( + ChatCompletionResponseEventType, + ToolPromptFormat, +) from llama_toolchain.inference.api.endpoints import ChatCompletionRequest from llama_toolchain.inference.ollama.config import OllamaImplConfig from llama_toolchain.inference.ollama.ollama import get_provider_impl @@ -236,6 +239,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): ], stream=True, tools=[self.custom_tool_defn], + tool_prompt_format=ToolPromptFormat.function_tag, ) iterator = self.api.chat_completion(request) events = [] diff --git a/tests/test_tool_utils.py b/tests/test_prepare_messages.py similarity index 61% rename from tests/test_tool_utils.py rename to tests/test_prepare_messages.py index 360c769b1..49624b04d 100644 --- a/tests/test_tool_utils.py +++ b/tests/test_prepare_messages.py @@ -2,12 +2,12 @@ import unittest from llama_models.llama3.api import * # noqa: F403 from llama_toolchain.inference.api import * # noqa: F403 -from llama_toolchain.inference.prepare_messages import prepare_messages_for_tools +from llama_toolchain.inference.prepare_messages import prepare_messages MODEL = "Meta-Llama3.1-8B-Instruct" -class ToolUtilsTests(unittest.IsolatedAsyncioTestCase): +class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): async def test_system_default(self): content = "Hello !" request = ChatCompletionRequest( @@ -16,12 +16,10 @@ class ToolUtilsTests(unittest.IsolatedAsyncioTestCase): UserMessage(content=content), ], ) - request = prepare_messages_for_tools(request) - self.assertEqual(len(request.messages), 2) - self.assertEqual(request.messages[-1].content, content) - self.assertTrue( - "Cutting Knowledge Date: December 2023" in request.messages[0].content - ) + messages = prepare_messages(request) + self.assertEqual(len(messages), 2) + self.assertEqual(messages[-1].content, content) + self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) async def test_system_builtin_only(self): content = "Hello !" @@ -35,13 +33,11 @@ class ToolUtilsTests(unittest.IsolatedAsyncioTestCase): ToolDefinition(tool_name=BuiltinTool.brave_search), ], ) - request = prepare_messages_for_tools(request) - self.assertEqual(len(request.messages), 2) - self.assertEqual(request.messages[-1].content, content) - self.assertTrue( - "Cutting Knowledge Date: December 2023" in request.messages[0].content - ) - self.assertTrue("Tools: brave_search" in request.messages[0].content) + messages = prepare_messages(request) + self.assertEqual(len(messages), 2) + self.assertEqual(messages[-1].content, content) + self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) + self.assertTrue("Tools: brave_search" in messages[0].content) async def test_system_custom_only(self): content = "Hello !" @@ -65,14 +61,12 @@ class ToolUtilsTests(unittest.IsolatedAsyncioTestCase): ], tool_prompt_format=ToolPromptFormat.json, ) - request = prepare_messages_for_tools(request) - self.assertEqual(len(request.messages), 3) - self.assertTrue("Environment: ipython" in request.messages[0].content) + messages = prepare_messages(request) + self.assertEqual(len(messages), 3) + self.assertTrue("Environment: ipython" in messages[0].content) - self.assertTrue( - "Return function calls in JSON format" in request.messages[1].content - ) - self.assertEqual(request.messages[-1].content, content) + self.assertTrue("Return function calls in JSON format" in messages[1].content) + self.assertEqual(messages[-1].content, content) async def test_system_custom_and_builtin(self): content = "Hello !" @@ -97,16 +91,14 @@ class ToolUtilsTests(unittest.IsolatedAsyncioTestCase): ), ], ) - request = prepare_messages_for_tools(request) - self.assertEqual(len(request.messages), 3) + messages = prepare_messages(request) + self.assertEqual(len(messages), 3) - self.assertTrue("Environment: ipython" in request.messages[0].content) - self.assertTrue("Tools: brave_search" in request.messages[0].content) + self.assertTrue("Environment: ipython" in messages[0].content) + self.assertTrue("Tools: brave_search" in messages[0].content) - self.assertTrue( - "Return function calls in JSON format" in request.messages[1].content - ) - self.assertEqual(request.messages[-1].content, content) + self.assertTrue("Return function calls in JSON format" in messages[1].content) + self.assertEqual(messages[-1].content, content) async def test_user_provided_system_message(self): content = "Hello !" @@ -121,8 +113,8 @@ class ToolUtilsTests(unittest.IsolatedAsyncioTestCase): ToolDefinition(tool_name=BuiltinTool.code_interpreter), ], ) - request = prepare_messages_for_tools(request) - self.assertEqual(len(request.messages), 2, request.messages) - self.assertTrue(request.messages[0].content.endswith(system_prompt)) + messages = prepare_messages(request) + self.assertEqual(len(messages), 2, messages) + self.assertTrue(messages[0].content.endswith(system_prompt)) - self.assertEqual(request.messages[-1].content, content) + self.assertEqual(messages[-1].content, content) diff --git a/tests/test_prompt_templates.py b/tests/test_prompt_templates.py new file mode 100644 index 000000000..94825e327 --- /dev/null +++ b/tests/test_prompt_templates.py @@ -0,0 +1,101 @@ +import textwrap +import unittest +from datetime import datetime + +from llama_toolchain.common.prompt_templates.system_prompts import ( + BuiltinToolGenerator, + FunctionTagCustomToolGenerator, + JsonCustomToolGenerator, + SystemDefaultGenerator, +) + + +class PromptTemplateTests(unittest.TestCase): + + def check_generator_output(self, generator, expected_text): + example = generator.data_examples()[0] + + pt = generator.gen(example) + text = pt.render() + # print(text) # debugging + self.assertEqual(text, expected_text) + + def test_system_default(self): + generator = SystemDefaultGenerator() + today = datetime.now().strftime("%d %B %Y") + expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}" + self.check_generator_output(generator, expected_text) + + def test_system_builtin_only(self): + generator = BuiltinToolGenerator() + expected_text = textwrap.dedent( + """ + Environment: ipython + Tools: brave_search, wolfram_alpha + """ + ) + self.check_generator_output(generator, expected_text.strip("\n")) + + def test_system_custom_only(self): + self.maxDiff = None + generator = JsonCustomToolGenerator() + expected_text = textwrap.dedent( + """ + Answer the user's question by making use of the following functions if needed. + If none of the function can be used, please say so. + Here is a list of functions in JSON format: + { + "type": "function", + "function": { + "name": "trending_songs", + "description": "Returns the trending songs on a Music site", + "parameters": { + "type": "object", + "properties": [ + { + "n": { + "type": "object", + "description": "The number of songs to return" + } + }, + { + "genre": { + "type": "object", + "description": "The genre of the songs to return" + } + } + ], + "required": ["n"] + } + } + } + + Return function calls in JSON format. + """ + ) + self.check_generator_output(generator, expected_text.strip("\n")) + + def test_system_custom_function_tag(self): + self.maxDiff = None + generator = FunctionTagCustomToolGenerator() + expected_text = textwrap.dedent( + """ + You have access to the following functions: + + Use the function 'trending_songs' to 'Returns the trending songs on a Music site': + {"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}} + + Think very carefully before calling functions. + If a you choose to call a function ONLY reply in the following format with no prefix or suffix: + + {"example_name": "example_value"} + + Reminder: + - If looking for real time information use relevant functions before falling back to brave_search + - Function calls MUST follow the specified format, start with + - Required parameters MUST be specified + - Only call one function at a time + - Put the entire function call reply on one line + """ + ) + self.check_generator_output(generator, expected_text.strip("\n"))