diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py
index 56428c425..690a8499b 100644
--- a/llama_toolchain/agentic_system/client.py
+++ b/llama_toolchain/agentic_system/client.py
@@ -17,6 +17,7 @@ from llama_models.llama3.api.datatypes import (
BuiltinTool,
SamplingParams,
ToolParamDefinition,
+ ToolPromptFormat,
UserMessage,
)
from termcolor import cprint
@@ -32,7 +33,6 @@ from .api import (
AgenticSystemToolDefinition,
AgenticSystemTurnCreateRequest,
AgenticSystemTurnResponseStreamChunk,
- ToolPromptFormat,
)
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index 5de17d7b9..2c769a5e1 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -10,6 +10,8 @@ import uuid
from datetime import datetime
from typing import AsyncGenerator, List, Optional
+from llama_models.llama3.api.datatypes import ToolPromptFormat
+
from termcolor import cprint
from llama_toolchain.agentic_system.api.datatypes import (
@@ -26,12 +28,10 @@ from llama_toolchain.agentic_system.api.datatypes import (
ShieldCallStep,
StepType,
ToolExecutionStep,
- ToolPromptFormat,
Turn,
)
from llama_toolchain.inference.api import ChatCompletionRequest, Inference
-
from llama_toolchain.inference.api.datatypes import (
Attachment,
BuiltinTool,
diff --git a/llama_toolchain/agentic_system/utils.py b/llama_toolchain/agentic_system/utils.py
index b2ba4fec8..73fe9f918 100644
--- a/llama_toolchain/agentic_system/utils.py
+++ b/llama_toolchain/agentic_system/utils.py
@@ -7,7 +7,12 @@
import uuid
from typing import Any, List, Optional
-from llama_models.llama3.api.datatypes import BuiltinTool, Message, SamplingParams
+from llama_models.llama3.api.datatypes import (
+ BuiltinTool,
+ Message,
+ SamplingParams,
+ ToolPromptFormat,
+)
from llama_toolchain.agentic_system.api import (
AgenticSystemCreateRequest,
@@ -15,7 +20,6 @@ from llama_toolchain.agentic_system.api import (
AgenticSystemSessionCreateRequest,
AgenticSystemToolDefinition,
)
-from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
from llama_toolchain.agentic_system.client import AgenticSystemClient
from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import (
diff --git a/llama_toolchain/common/prompt_templates/base.py b/llama_toolchain/common/prompt_templates/base.py
deleted file mode 100644
index de229bcb2..000000000
--- a/llama_toolchain/common/prompt_templates/base.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from dataclasses import dataclass
-from typing import Any, Dict, List
-
-from jinja2 import Template
-
-
-@dataclass
-class PromptTemplate:
- template: str
- data: Dict[str, Any]
-
- def render(self):
- template = Template(self.template)
- return template.render(self.data)
-
-
-class PromptTemplateGeneratorBase:
- """
- Base class for prompt template generators.
- """
-
- def gen(self, *args, **kwargs) -> PromptTemplate:
- raise NotImplementedError()
-
- def data_examples(self) -> List[Any]:
- raise NotImplementedError()
diff --git a/llama_toolchain/common/prompt_templates/system_prompts.py b/llama_toolchain/common/prompt_templates/system_prompts.py
deleted file mode 100644
index aeabb116e..000000000
--- a/llama_toolchain/common/prompt_templates/system_prompts.py
+++ /dev/null
@@ -1,206 +0,0 @@
-import textwrap
-from datetime import datetime
-from typing import Any, Dict, List
-
-from llama_models.llama3.api.datatypes import (
- BuiltinTool,
- ToolDefinition,
- ToolParamDefinition,
-)
-
-from .base import PromptTemplate, PromptTemplateGeneratorBase
-
-
-class SystemDefaultGenerator(PromptTemplateGeneratorBase):
-
- def gen(self, *args, **kwargs) -> PromptTemplate:
- template_str = textwrap.dedent(
- """
- Cutting Knowledge Date: December 2023
- Today Date: {{ today }}
- """
- )
- return PromptTemplate(
- template_str.lstrip("\n"),
- {"today": datetime.now().strftime("%d %B %Y")},
- )
-
- def data_examples(self) -> List[Any]:
- return [None]
-
-
-class BuiltinToolGenerator(PromptTemplateGeneratorBase):
-
- def _tool_breakdown(self, tools: List[ToolDefinition]):
- builtin_tools, custom_tools = [], []
- for dfn in tools:
- if isinstance(dfn.tool_name, BuiltinTool):
- builtin_tools.append(dfn)
- else:
- custom_tools.append(dfn)
-
- return builtin_tools, custom_tools
-
- def gen(self, tools: List[ToolDefinition]) -> PromptTemplate:
- builtin_tools, custom_tools = self._tool_breakdown(tools)
- data = []
- template_str = textwrap.dedent(
- """
- {% if builtin_tools or custom_tools -%}
- Environment: ipython
- {% endif -%}
- {% set builtin_tools = builtin_tools | reject('equalto', 'code_interpreter') | list -%}
- {% if builtin_tools -%}
- Tools: {{ builtin_tools | join(", ") | trim -}}
- {% endif %}
- """
- )
- return PromptTemplate(
- template_str.lstrip("\n"),
- {
- "builtin_tools": [t.tool_name.value for t in builtin_tools],
- "custom_tools": custom_tools,
- },
- )
-
- def data_examples(self) -> List[List[ToolDefinition]]:
- return [
- # builtin tools
- [
- ToolDefinition(tool_name=BuiltinTool.code_interpreter),
- ToolDefinition(tool_name=BuiltinTool.brave_search),
- ToolDefinition(tool_name=BuiltinTool.wolfram_alpha),
- ],
- # only code interpretor
- [
- ToolDefinition(tool_name=BuiltinTool.code_interpreter),
- ],
- ]
-
-
-class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
-
- def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
- template_str = textwrap.dedent(
- """
- Answer the user's question by making use of the following functions if needed.
- If none of the function can be used, please say so.
- Here is a list of functions in JSON format:
- {% for t in custom_tools -%}
- {# manually setting up JSON because jinja sorts keys in unexpected ways -#}
- {%- set tname = t.tool_name -%}
- {%- set tdesc = t.description -%}
- {%- set tparams = t.parameters -%}
- {%- set required_params = [] -%}
- {%- for name, param in tparams.items() if param.required == true -%}
- {%- set _ = required_params.append(name) -%}
- {%- endfor -%}
- {
- "type": "function",
- "function": {
- "name": "{{tname}}",
- "description": "{{tdesc}}",
- "parameters": {
- "type": "object",
- "properties": [
- {%- for name, param in tparams.items() %}
- {
- "{{name}}": {
- "type": "object",
- "description": "{{param.description}}"
- }
- }{% if not loop.last %},{% endif %}
- {%- endfor %}
- ],
- "required": {{ required_params | tojson }}
- }
- }
- }
- {% endfor %}
- Return function calls in JSON format.
- """
- )
-
- return PromptTemplate(
- template_str.lstrip("\n"),
- {"custom_tools": [t.model_dump() for t in custom_tools]},
- )
-
- def data_examples(self) -> List[List[ToolDefinition]]:
- return [
- [
- ToolDefinition(
- tool_name="trending_songs",
- description="Returns the trending songs on a Music site",
- parameters={
- "n": ToolParamDefinition(
- param_type="int",
- description="The number of songs to return",
- required=True,
- ),
- "genre": ToolParamDefinition(
- param_type="str",
- description="The genre of the songs to return",
- required=False,
- ),
- },
- ),
- ]
- ]
-
-
-class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
-
- def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
- template_str = textwrap.dedent(
- """
- You have access to the following functions:
-
- {% for t in custom_tools %}
- {#- manually setting up JSON because jinja sorts keys in unexpected ways -#}
- {%- set tname = t.tool_name -%}
- {%- set tdesc = t.description -%}
- {%- set tparams = t.parameters | tojson -%}
- Use the function '{{ tname }}' to '{{ tdesc }}':
- {"name": "{{tname}}", "description": "{{tdesc}}", "parameters": {{tparams}}}
-
- {% endfor -%}
- Think very carefully before calling functions.
- If a you choose to call a function ONLY reply in the following format with no prefix or suffix:
-
- {"example_name": "example_value"}
-
- Reminder:
- - If looking for real time information use relevant functions before falling back to brave_search
- - Function calls MUST follow the specified format, start with
- - Required parameters MUST be specified
- - Only call one function at a time
- - Put the entire function call reply on one line
- """
- )
- return PromptTemplate(
- template_str.lstrip("\n"),
- {"custom_tools": [t.model_dump() for t in custom_tools]},
- )
-
- def data_examples(self) -> List[List[ToolDefinition]]:
- return [
- [
- ToolDefinition(
- tool_name="trending_songs",
- description="Returns the trending songs on a Music site",
- parameters={
- "n": ToolParamDefinition(
- param_type="int",
- description="The number of songs to return",
- required=True,
- ),
- "genre": ToolParamDefinition(
- param_type="str",
- description="The genre of the songs to return",
- required=False,
- ),
- },
- ),
- ]
- ]
diff --git a/llama_toolchain/inference/api/datatypes.py b/llama_toolchain/inference/api/datatypes.py
index cad8f4377..571ecc3ea 100644
--- a/llama_toolchain/inference/api/datatypes.py
+++ b/llama_toolchain/inference/api/datatypes.py
@@ -15,41 +15,6 @@ from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
-@json_schema_type
-class ToolChoice(Enum):
- auto = "auto"
- required = "required"
-
-
-@json_schema_type
-class ToolPromptFormat(Enum):
- """This Enum refers to the prompt format for calling zero shot tools
-
- `json` --
- Refers to the json format for calling tools.
- The json format takes the form like
- {
- "type": "function",
- "function" : {
- "name": "function_name",
- "description": "function_description",
- "parameters": {...}
- }
- }
-
- `function_tag` --
- This is an example of how you could define
- your own user defined format for making tool calls.
- The function_tag format looks like this,
- (parameters)
-
- The detailed prompts for each of these formats are defined in `system_prompt.py`
- """
-
- json = "json"
- function_tag = "function_tag"
-
-
class LogProbConfig(BaseModel):
top_k: Optional[int] = 0
diff --git a/llama_toolchain/inference/api/endpoints.py b/llama_toolchain/inference/api/endpoints.py
index 26773e439..a4c4d4095 100644
--- a/llama_toolchain/inference/api/endpoints.py
+++ b/llama_toolchain/inference/api/endpoints.py
@@ -7,7 +7,7 @@
from .datatypes import * # noqa: F403
from typing import Optional, Protocol
-from llama_models.llama3.api.datatypes import ToolDefinition
+from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
# this dependency is annoying and we need a forked up version anyway
from llama_models.schema_utils import webmethod
@@ -16,7 +16,7 @@ from llama_models.schema_utils import webmethod
@json_schema_type
class CompletionRequest(BaseModel):
model: str
- content: InterleavedTextAttachment
+ content: InterleavedTextMedia
sampling_params: Optional[SamplingParams] = SamplingParams()
stream: Optional[bool] = False
@@ -41,7 +41,7 @@ class CompletionResponseStreamChunk(BaseModel):
@json_schema_type
class BatchCompletionRequest(BaseModel):
model: str
- content_batch: List[InterleavedTextAttachment]
+ content_batch: List[InterleavedTextMedia]
sampling_params: Optional[SamplingParams] = SamplingParams()
logprobs: Optional[LogProbConfig] = None
diff --git a/llama_toolchain/inference/meta_reference/config.py b/llama_toolchain/inference/meta_reference/config.py
index f85934118..d2e601680 100644
--- a/llama_toolchain/inference/meta_reference/config.py
+++ b/llama_toolchain/inference/meta_reference/config.py
@@ -11,10 +11,10 @@ from llama_models.datatypes import ModelFamily
from llama_models.schema_utils import json_schema_type
from llama_models.sku_list import all_registered_models
-from llama_toolchain.inference.api import QuantizationConfig
-
from pydantic import BaseModel, Field, field_validator
+from llama_toolchain.inference.api import QuantizationConfig
+
@json_schema_type
class MetaReferenceImplConfig(BaseModel):
diff --git a/llama_toolchain/inference/prepare_messages.py b/llama_toolchain/inference/prepare_messages.py
index 83aff57f9..d5ce648e1 100644
--- a/llama_toolchain/inference/prepare_messages.py
+++ b/llama_toolchain/inference/prepare_messages.py
@@ -1,7 +1,8 @@
import textwrap
+from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
-from llama_toolchain.common.prompt_templates.system_prompts import (
+from llama_models.llama3.prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
diff --git a/tests/test_e2e.py b/tests/test_e2e.py
index 41afb9db0..ea0246f20 100644
--- a/tests/test_e2e.py
+++ b/tests/test_e2e.py
@@ -15,7 +15,7 @@ from llama_toolchain.agentic_system.event_logger import EventLogger, LogEvent
from llama_toolchain.agentic_system.utils import get_agent_system_instance
from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_toolchain.agentic_system.api.datatypes import StepType, ToolPromptFormat
+from llama_toolchain.agentic_system.api.datatypes import StepType
from llama_toolchain.tools.custom.datatypes import CustomTool
from tests.example_custom_tool import GetBoilingPointTool
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 6dcd60f11..0a772d26e 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -14,13 +14,11 @@ from llama_models.llama3.api.datatypes import (
SystemMessage,
ToolDefinition,
ToolParamDefinition,
+ ToolPromptFormat,
ToolResponseMessage,
UserMessage,
)
-from llama_toolchain.inference.api.datatypes import (
- ChatCompletionResponseEventType,
- ToolPromptFormat,
-)
+from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig
diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py
index 5ff1b94f9..8319cab3d 100644
--- a/tests/test_ollama_inference.py
+++ b/tests/test_ollama_inference.py
@@ -10,13 +10,11 @@ from llama_models.llama3.api.datatypes import (
SystemMessage,
ToolDefinition,
ToolParamDefinition,
+ ToolPromptFormat,
ToolResponseMessage,
UserMessage,
)
-from llama_toolchain.inference.api.datatypes import (
- ChatCompletionResponseEventType,
- ToolPromptFormat,
-)
+from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
from llama_toolchain.inference.ollama.config import OllamaImplConfig
from llama_toolchain.inference.ollama.ollama import get_provider_impl
diff --git a/tests/test_prompt_templates.py b/tests/test_prompt_templates.py
deleted file mode 100644
index 94825e327..000000000
--- a/tests/test_prompt_templates.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import textwrap
-import unittest
-from datetime import datetime
-
-from llama_toolchain.common.prompt_templates.system_prompts import (
- BuiltinToolGenerator,
- FunctionTagCustomToolGenerator,
- JsonCustomToolGenerator,
- SystemDefaultGenerator,
-)
-
-
-class PromptTemplateTests(unittest.TestCase):
-
- def check_generator_output(self, generator, expected_text):
- example = generator.data_examples()[0]
-
- pt = generator.gen(example)
- text = pt.render()
- # print(text) # debugging
- self.assertEqual(text, expected_text)
-
- def test_system_default(self):
- generator = SystemDefaultGenerator()
- today = datetime.now().strftime("%d %B %Y")
- expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
- self.check_generator_output(generator, expected_text)
-
- def test_system_builtin_only(self):
- generator = BuiltinToolGenerator()
- expected_text = textwrap.dedent(
- """
- Environment: ipython
- Tools: brave_search, wolfram_alpha
- """
- )
- self.check_generator_output(generator, expected_text.strip("\n"))
-
- def test_system_custom_only(self):
- self.maxDiff = None
- generator = JsonCustomToolGenerator()
- expected_text = textwrap.dedent(
- """
- Answer the user's question by making use of the following functions if needed.
- If none of the function can be used, please say so.
- Here is a list of functions in JSON format:
- {
- "type": "function",
- "function": {
- "name": "trending_songs",
- "description": "Returns the trending songs on a Music site",
- "parameters": {
- "type": "object",
- "properties": [
- {
- "n": {
- "type": "object",
- "description": "The number of songs to return"
- }
- },
- {
- "genre": {
- "type": "object",
- "description": "The genre of the songs to return"
- }
- }
- ],
- "required": ["n"]
- }
- }
- }
-
- Return function calls in JSON format.
- """
- )
- self.check_generator_output(generator, expected_text.strip("\n"))
-
- def test_system_custom_function_tag(self):
- self.maxDiff = None
- generator = FunctionTagCustomToolGenerator()
- expected_text = textwrap.dedent(
- """
- You have access to the following functions:
-
- Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
- {"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
-
- Think very carefully before calling functions.
- If a you choose to call a function ONLY reply in the following format with no prefix or suffix:
-
- {"example_name": "example_value"}
-
- Reminder:
- - If looking for real time information use relevant functions before falling back to brave_search
- - Function calls MUST follow the specified format, start with
- - Required parameters MUST be specified
- - Only call one function at a time
- - Put the entire function call reply on one line
- """
- )
- self.check_generator_output(generator, expected_text.strip("\n"))