From 5655266d58ec7de442023db51238c38bfac87d79 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Fri, 23 Aug 2024 14:58:52 -0700 Subject: [PATCH] Moved ToolPromptFormat and jinja templates to llama_models.llama3.api --- llama_toolchain/agentic_system/client.py | 2 +- .../meta_reference/agent_instance.py | 4 +- llama_toolchain/agentic_system/utils.py | 8 +- .../common/prompt_templates/base.py | 26 --- .../common/prompt_templates/system_prompts.py | 206 ------------------ llama_toolchain/inference/api/datatypes.py | 35 --- llama_toolchain/inference/api/endpoints.py | 6 +- .../inference/meta_reference/config.py | 4 +- llama_toolchain/inference/prepare_messages.py | 3 +- tests/test_e2e.py | 2 +- tests/test_inference.py | 6 +- tests/test_ollama_inference.py | 6 +- tests/test_prompt_templates.py | 101 --------- 13 files changed, 21 insertions(+), 388 deletions(-) delete mode 100644 llama_toolchain/common/prompt_templates/base.py delete mode 100644 llama_toolchain/common/prompt_templates/system_prompts.py delete mode 100644 tests/test_prompt_templates.py diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py index 56428c425..690a8499b 100644 --- a/llama_toolchain/agentic_system/client.py +++ b/llama_toolchain/agentic_system/client.py @@ -17,6 +17,7 @@ from llama_models.llama3.api.datatypes import ( BuiltinTool, SamplingParams, ToolParamDefinition, + ToolPromptFormat, UserMessage, ) from termcolor import cprint @@ -32,7 +33,6 @@ from .api import ( AgenticSystemToolDefinition, AgenticSystemTurnCreateRequest, AgenticSystemTurnResponseStreamChunk, - ToolPromptFormat, ) diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py index 5de17d7b9..2c769a5e1 100644 --- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py +++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py @@ -10,6 +10,8 @@ import uuid from datetime import datetime from typing import AsyncGenerator, List, Optional +from llama_models.llama3.api.datatypes import ToolPromptFormat + from termcolor import cprint from llama_toolchain.agentic_system.api.datatypes import ( @@ -26,12 +28,10 @@ from llama_toolchain.agentic_system.api.datatypes import ( ShieldCallStep, StepType, ToolExecutionStep, - ToolPromptFormat, Turn, ) from llama_toolchain.inference.api import ChatCompletionRequest, Inference - from llama_toolchain.inference.api.datatypes import ( Attachment, BuiltinTool, diff --git a/llama_toolchain/agentic_system/utils.py b/llama_toolchain/agentic_system/utils.py index b2ba4fec8..73fe9f918 100644 --- a/llama_toolchain/agentic_system/utils.py +++ b/llama_toolchain/agentic_system/utils.py @@ -7,7 +7,12 @@ import uuid from typing import Any, List, Optional -from llama_models.llama3.api.datatypes import BuiltinTool, Message, SamplingParams +from llama_models.llama3.api.datatypes import ( + BuiltinTool, + Message, + SamplingParams, + ToolPromptFormat, +) from llama_toolchain.agentic_system.api import ( AgenticSystemCreateRequest, @@ -15,7 +20,6 @@ from llama_toolchain.agentic_system.api import ( AgenticSystemSessionCreateRequest, AgenticSystemToolDefinition, ) -from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat from llama_toolchain.agentic_system.client import AgenticSystemClient from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import ( diff --git a/llama_toolchain/common/prompt_templates/base.py b/llama_toolchain/common/prompt_templates/base.py deleted file mode 100644 index de229bcb2..000000000 --- a/llama_toolchain/common/prompt_templates/base.py +++ /dev/null @@ -1,26 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Dict, List - -from jinja2 import Template - - -@dataclass -class PromptTemplate: - template: str - data: Dict[str, Any] - - def render(self): - template = Template(self.template) - return template.render(self.data) - - -class PromptTemplateGeneratorBase: - """ - Base class for prompt template generators. - """ - - def gen(self, *args, **kwargs) -> PromptTemplate: - raise NotImplementedError() - - def data_examples(self) -> List[Any]: - raise NotImplementedError() diff --git a/llama_toolchain/common/prompt_templates/system_prompts.py b/llama_toolchain/common/prompt_templates/system_prompts.py deleted file mode 100644 index aeabb116e..000000000 --- a/llama_toolchain/common/prompt_templates/system_prompts.py +++ /dev/null @@ -1,206 +0,0 @@ -import textwrap -from datetime import datetime -from typing import Any, Dict, List - -from llama_models.llama3.api.datatypes import ( - BuiltinTool, - ToolDefinition, - ToolParamDefinition, -) - -from .base import PromptTemplate, PromptTemplateGeneratorBase - - -class SystemDefaultGenerator(PromptTemplateGeneratorBase): - - def gen(self, *args, **kwargs) -> PromptTemplate: - template_str = textwrap.dedent( - """ - Cutting Knowledge Date: December 2023 - Today Date: {{ today }} - """ - ) - return PromptTemplate( - template_str.lstrip("\n"), - {"today": datetime.now().strftime("%d %B %Y")}, - ) - - def data_examples(self) -> List[Any]: - return [None] - - -class BuiltinToolGenerator(PromptTemplateGeneratorBase): - - def _tool_breakdown(self, tools: List[ToolDefinition]): - builtin_tools, custom_tools = [], [] - for dfn in tools: - if isinstance(dfn.tool_name, BuiltinTool): - builtin_tools.append(dfn) - else: - custom_tools.append(dfn) - - return builtin_tools, custom_tools - - def gen(self, tools: List[ToolDefinition]) -> PromptTemplate: - builtin_tools, custom_tools = self._tool_breakdown(tools) - data = [] - template_str = textwrap.dedent( - """ - {% if builtin_tools or custom_tools -%} - Environment: ipython - {% endif -%} - {% set builtin_tools = builtin_tools | reject('equalto', 'code_interpreter') | list -%} - {% if builtin_tools -%} - Tools: {{ builtin_tools | join(", ") | trim -}} - {% endif %} - """ - ) - return PromptTemplate( - template_str.lstrip("\n"), - { - "builtin_tools": [t.tool_name.value for t in builtin_tools], - "custom_tools": custom_tools, - }, - ) - - def data_examples(self) -> List[List[ToolDefinition]]: - return [ - # builtin tools - [ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ToolDefinition(tool_name=BuiltinTool.wolfram_alpha), - ], - # only code interpretor - [ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - ] - - -class JsonCustomToolGenerator(PromptTemplateGeneratorBase): - - def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: - template_str = textwrap.dedent( - """ - Answer the user's question by making use of the following functions if needed. - If none of the function can be used, please say so. - Here is a list of functions in JSON format: - {% for t in custom_tools -%} - {# manually setting up JSON because jinja sorts keys in unexpected ways -#} - {%- set tname = t.tool_name -%} - {%- set tdesc = t.description -%} - {%- set tparams = t.parameters -%} - {%- set required_params = [] -%} - {%- for name, param in tparams.items() if param.required == true -%} - {%- set _ = required_params.append(name) -%} - {%- endfor -%} - { - "type": "function", - "function": { - "name": "{{tname}}", - "description": "{{tdesc}}", - "parameters": { - "type": "object", - "properties": [ - {%- for name, param in tparams.items() %} - { - "{{name}}": { - "type": "object", - "description": "{{param.description}}" - } - }{% if not loop.last %},{% endif %} - {%- endfor %} - ], - "required": {{ required_params | tojson }} - } - } - } - {% endfor %} - Return function calls in JSON format. - """ - ) - - return PromptTemplate( - template_str.lstrip("\n"), - {"custom_tools": [t.model_dump() for t in custom_tools]}, - ) - - def data_examples(self) -> List[List[ToolDefinition]]: - return [ - [ - ToolDefinition( - tool_name="trending_songs", - description="Returns the trending songs on a Music site", - parameters={ - "n": ToolParamDefinition( - param_type="int", - description="The number of songs to return", - required=True, - ), - "genre": ToolParamDefinition( - param_type="str", - description="The genre of the songs to return", - required=False, - ), - }, - ), - ] - ] - - -class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase): - - def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: - template_str = textwrap.dedent( - """ - You have access to the following functions: - - {% for t in custom_tools %} - {#- manually setting up JSON because jinja sorts keys in unexpected ways -#} - {%- set tname = t.tool_name -%} - {%- set tdesc = t.description -%} - {%- set tparams = t.parameters | tojson -%} - Use the function '{{ tname }}' to '{{ tdesc }}': - {"name": "{{tname}}", "description": "{{tdesc}}", "parameters": {{tparams}}} - - {% endfor -%} - Think very carefully before calling functions. - If a you choose to call a function ONLY reply in the following format with no prefix or suffix: - - {"example_name": "example_value"} - - Reminder: - - If looking for real time information use relevant functions before falling back to brave_search - - Function calls MUST follow the specified format, start with - - Required parameters MUST be specified - - Only call one function at a time - - Put the entire function call reply on one line - """ - ) - return PromptTemplate( - template_str.lstrip("\n"), - {"custom_tools": [t.model_dump() for t in custom_tools]}, - ) - - def data_examples(self) -> List[List[ToolDefinition]]: - return [ - [ - ToolDefinition( - tool_name="trending_songs", - description="Returns the trending songs on a Music site", - parameters={ - "n": ToolParamDefinition( - param_type="int", - description="The number of songs to return", - required=True, - ), - "genre": ToolParamDefinition( - param_type="str", - description="The genre of the songs to return", - required=False, - ), - }, - ), - ] - ] diff --git a/llama_toolchain/inference/api/datatypes.py b/llama_toolchain/inference/api/datatypes.py index cad8f4377..571ecc3ea 100644 --- a/llama_toolchain/inference/api/datatypes.py +++ b/llama_toolchain/inference/api/datatypes.py @@ -15,41 +15,6 @@ from typing_extensions import Annotated from llama_models.llama3.api.datatypes import * # noqa: F403 -@json_schema_type -class ToolChoice(Enum): - auto = "auto" - required = "required" - - -@json_schema_type -class ToolPromptFormat(Enum): - """This Enum refers to the prompt format for calling zero shot tools - - `json` -- - Refers to the json format for calling tools. - The json format takes the form like - { - "type": "function", - "function" : { - "name": "function_name", - "description": "function_description", - "parameters": {...} - } - } - - `function_tag` -- - This is an example of how you could define - your own user defined format for making tool calls. - The function_tag format looks like this, - (parameters) - - The detailed prompts for each of these formats are defined in `system_prompt.py` - """ - - json = "json" - function_tag = "function_tag" - - class LogProbConfig(BaseModel): top_k: Optional[int] = 0 diff --git a/llama_toolchain/inference/api/endpoints.py b/llama_toolchain/inference/api/endpoints.py index 26773e439..a4c4d4095 100644 --- a/llama_toolchain/inference/api/endpoints.py +++ b/llama_toolchain/inference/api/endpoints.py @@ -7,7 +7,7 @@ from .datatypes import * # noqa: F403 from typing import Optional, Protocol -from llama_models.llama3.api.datatypes import ToolDefinition +from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat # this dependency is annoying and we need a forked up version anyway from llama_models.schema_utils import webmethod @@ -16,7 +16,7 @@ from llama_models.schema_utils import webmethod @json_schema_type class CompletionRequest(BaseModel): model: str - content: InterleavedTextAttachment + content: InterleavedTextMedia sampling_params: Optional[SamplingParams] = SamplingParams() stream: Optional[bool] = False @@ -41,7 +41,7 @@ class CompletionResponseStreamChunk(BaseModel): @json_schema_type class BatchCompletionRequest(BaseModel): model: str - content_batch: List[InterleavedTextAttachment] + content_batch: List[InterleavedTextMedia] sampling_params: Optional[SamplingParams] = SamplingParams() logprobs: Optional[LogProbConfig] = None diff --git a/llama_toolchain/inference/meta_reference/config.py b/llama_toolchain/inference/meta_reference/config.py index f85934118..d2e601680 100644 --- a/llama_toolchain/inference/meta_reference/config.py +++ b/llama_toolchain/inference/meta_reference/config.py @@ -11,10 +11,10 @@ from llama_models.datatypes import ModelFamily from llama_models.schema_utils import json_schema_type from llama_models.sku_list import all_registered_models -from llama_toolchain.inference.api import QuantizationConfig - from pydantic import BaseModel, Field, field_validator +from llama_toolchain.inference.api import QuantizationConfig + @json_schema_type class MetaReferenceImplConfig(BaseModel): diff --git a/llama_toolchain/inference/prepare_messages.py b/llama_toolchain/inference/prepare_messages.py index 83aff57f9..d5ce648e1 100644 --- a/llama_toolchain/inference/prepare_messages.py +++ b/llama_toolchain/inference/prepare_messages.py @@ -1,7 +1,8 @@ import textwrap +from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_toolchain.inference.api import * # noqa: F403 -from llama_toolchain.common.prompt_templates.system_prompts import ( +from llama_models.llama3.prompt_templates import ( BuiltinToolGenerator, FunctionTagCustomToolGenerator, JsonCustomToolGenerator, diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 41afb9db0..ea0246f20 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -15,7 +15,7 @@ from llama_toolchain.agentic_system.event_logger import EventLogger, LogEvent from llama_toolchain.agentic_system.utils import get_agent_system_instance from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_toolchain.agentic_system.api.datatypes import StepType, ToolPromptFormat +from llama_toolchain.agentic_system.api.datatypes import StepType from llama_toolchain.tools.custom.datatypes import CustomTool from tests.example_custom_tool import GetBoilingPointTool diff --git a/tests/test_inference.py b/tests/test_inference.py index 6dcd60f11..0a772d26e 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -14,13 +14,11 @@ from llama_models.llama3.api.datatypes import ( SystemMessage, ToolDefinition, ToolParamDefinition, + ToolPromptFormat, ToolResponseMessage, UserMessage, ) -from llama_toolchain.inference.api.datatypes import ( - ChatCompletionResponseEventType, - ToolPromptFormat, -) +from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType from llama_toolchain.inference.api.endpoints import ChatCompletionRequest from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py index 5ff1b94f9..8319cab3d 100644 --- a/tests/test_ollama_inference.py +++ b/tests/test_ollama_inference.py @@ -10,13 +10,11 @@ from llama_models.llama3.api.datatypes import ( SystemMessage, ToolDefinition, ToolParamDefinition, + ToolPromptFormat, ToolResponseMessage, UserMessage, ) -from llama_toolchain.inference.api.datatypes import ( - ChatCompletionResponseEventType, - ToolPromptFormat, -) +from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType from llama_toolchain.inference.api.endpoints import ChatCompletionRequest from llama_toolchain.inference.ollama.config import OllamaImplConfig from llama_toolchain.inference.ollama.ollama import get_provider_impl diff --git a/tests/test_prompt_templates.py b/tests/test_prompt_templates.py deleted file mode 100644 index 94825e327..000000000 --- a/tests/test_prompt_templates.py +++ /dev/null @@ -1,101 +0,0 @@ -import textwrap -import unittest -from datetime import datetime - -from llama_toolchain.common.prompt_templates.system_prompts import ( - BuiltinToolGenerator, - FunctionTagCustomToolGenerator, - JsonCustomToolGenerator, - SystemDefaultGenerator, -) - - -class PromptTemplateTests(unittest.TestCase): - - def check_generator_output(self, generator, expected_text): - example = generator.data_examples()[0] - - pt = generator.gen(example) - text = pt.render() - # print(text) # debugging - self.assertEqual(text, expected_text) - - def test_system_default(self): - generator = SystemDefaultGenerator() - today = datetime.now().strftime("%d %B %Y") - expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}" - self.check_generator_output(generator, expected_text) - - def test_system_builtin_only(self): - generator = BuiltinToolGenerator() - expected_text = textwrap.dedent( - """ - Environment: ipython - Tools: brave_search, wolfram_alpha - """ - ) - self.check_generator_output(generator, expected_text.strip("\n")) - - def test_system_custom_only(self): - self.maxDiff = None - generator = JsonCustomToolGenerator() - expected_text = textwrap.dedent( - """ - Answer the user's question by making use of the following functions if needed. - If none of the function can be used, please say so. - Here is a list of functions in JSON format: - { - "type": "function", - "function": { - "name": "trending_songs", - "description": "Returns the trending songs on a Music site", - "parameters": { - "type": "object", - "properties": [ - { - "n": { - "type": "object", - "description": "The number of songs to return" - } - }, - { - "genre": { - "type": "object", - "description": "The genre of the songs to return" - } - } - ], - "required": ["n"] - } - } - } - - Return function calls in JSON format. - """ - ) - self.check_generator_output(generator, expected_text.strip("\n")) - - def test_system_custom_function_tag(self): - self.maxDiff = None - generator = FunctionTagCustomToolGenerator() - expected_text = textwrap.dedent( - """ - You have access to the following functions: - - Use the function 'trending_songs' to 'Returns the trending songs on a Music site': - {"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}} - - Think very carefully before calling functions. - If a you choose to call a function ONLY reply in the following format with no prefix or suffix: - - {"example_name": "example_value"} - - Reminder: - - If looking for real time information use relevant functions before falling back to brave_search - - Function calls MUST follow the specified format, start with - - Required parameters MUST be specified - - Only call one function at a time - - Put the entire function call reply on one line - """ - ) - self.check_generator_output(generator, expected_text.strip("\n"))