mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Moved ToolPromptFormat and jinja templates to llama_models.llama3.api
This commit is contained in:
parent
ab8193c88c
commit
5655266d58
13 changed files with 21 additions and 388 deletions
|
@ -17,6 +17,7 @@ from llama_models.llama3.api.datatypes import (
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
ToolParamDefinition,
|
ToolParamDefinition,
|
||||||
|
ToolPromptFormat,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
@ -32,7 +33,6 @@ from .api import (
|
||||||
AgenticSystemToolDefinition,
|
AgenticSystemToolDefinition,
|
||||||
AgenticSystemTurnCreateRequest,
|
AgenticSystemTurnCreateRequest,
|
||||||
AgenticSystemTurnResponseStreamChunk,
|
AgenticSystemTurnResponseStreamChunk,
|
||||||
ToolPromptFormat,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,8 @@ import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import AsyncGenerator, List, Optional
|
from typing import AsyncGenerator, List, Optional
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||||
|
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_toolchain.agentic_system.api.datatypes import (
|
from llama_toolchain.agentic_system.api.datatypes import (
|
||||||
|
@ -26,12 +28,10 @@ from llama_toolchain.agentic_system.api.datatypes import (
|
||||||
ShieldCallStep,
|
ShieldCallStep,
|
||||||
StepType,
|
StepType,
|
||||||
ToolExecutionStep,
|
ToolExecutionStep,
|
||||||
ToolPromptFormat,
|
|
||||||
Turn,
|
Turn,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_toolchain.inference.api import ChatCompletionRequest, Inference
|
from llama_toolchain.inference.api import ChatCompletionRequest, Inference
|
||||||
|
|
||||||
from llama_toolchain.inference.api.datatypes import (
|
from llama_toolchain.inference.api.datatypes import (
|
||||||
Attachment,
|
Attachment,
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
|
|
|
@ -7,7 +7,12 @@
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import BuiltinTool, Message, SamplingParams
|
from llama_models.llama3.api.datatypes import (
|
||||||
|
BuiltinTool,
|
||||||
|
Message,
|
||||||
|
SamplingParams,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
|
|
||||||
from llama_toolchain.agentic_system.api import (
|
from llama_toolchain.agentic_system.api import (
|
||||||
AgenticSystemCreateRequest,
|
AgenticSystemCreateRequest,
|
||||||
|
@ -15,7 +20,6 @@ from llama_toolchain.agentic_system.api import (
|
||||||
AgenticSystemSessionCreateRequest,
|
AgenticSystemSessionCreateRequest,
|
||||||
AgenticSystemToolDefinition,
|
AgenticSystemToolDefinition,
|
||||||
)
|
)
|
||||||
from llama_toolchain.agentic_system.api.datatypes import ToolPromptFormat
|
|
||||||
from llama_toolchain.agentic_system.client import AgenticSystemClient
|
from llama_toolchain.agentic_system.client import AgenticSystemClient
|
||||||
|
|
||||||
from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import (
|
from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import (
|
||||||
|
|
|
@ -1,26 +0,0 @@
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
from jinja2 import Template
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PromptTemplate:
|
|
||||||
template: str
|
|
||||||
data: Dict[str, Any]
|
|
||||||
|
|
||||||
def render(self):
|
|
||||||
template = Template(self.template)
|
|
||||||
return template.render(self.data)
|
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplateGeneratorBase:
|
|
||||||
"""
|
|
||||||
Base class for prompt template generators.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def gen(self, *args, **kwargs) -> PromptTemplate:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def data_examples(self) -> List[Any]:
|
|
||||||
raise NotImplementedError()
|
|
|
@ -1,206 +0,0 @@
|
||||||
import textwrap
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import (
|
|
||||||
BuiltinTool,
|
|
||||||
ToolDefinition,
|
|
||||||
ToolParamDefinition,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .base import PromptTemplate, PromptTemplateGeneratorBase
|
|
||||||
|
|
||||||
|
|
||||||
class SystemDefaultGenerator(PromptTemplateGeneratorBase):
|
|
||||||
|
|
||||||
def gen(self, *args, **kwargs) -> PromptTemplate:
|
|
||||||
template_str = textwrap.dedent(
|
|
||||||
"""
|
|
||||||
Cutting Knowledge Date: December 2023
|
|
||||||
Today Date: {{ today }}
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
return PromptTemplate(
|
|
||||||
template_str.lstrip("\n"),
|
|
||||||
{"today": datetime.now().strftime("%d %B %Y")},
|
|
||||||
)
|
|
||||||
|
|
||||||
def data_examples(self) -> List[Any]:
|
|
||||||
return [None]
|
|
||||||
|
|
||||||
|
|
||||||
class BuiltinToolGenerator(PromptTemplateGeneratorBase):
|
|
||||||
|
|
||||||
def _tool_breakdown(self, tools: List[ToolDefinition]):
|
|
||||||
builtin_tools, custom_tools = [], []
|
|
||||||
for dfn in tools:
|
|
||||||
if isinstance(dfn.tool_name, BuiltinTool):
|
|
||||||
builtin_tools.append(dfn)
|
|
||||||
else:
|
|
||||||
custom_tools.append(dfn)
|
|
||||||
|
|
||||||
return builtin_tools, custom_tools
|
|
||||||
|
|
||||||
def gen(self, tools: List[ToolDefinition]) -> PromptTemplate:
|
|
||||||
builtin_tools, custom_tools = self._tool_breakdown(tools)
|
|
||||||
data = []
|
|
||||||
template_str = textwrap.dedent(
|
|
||||||
"""
|
|
||||||
{% if builtin_tools or custom_tools -%}
|
|
||||||
Environment: ipython
|
|
||||||
{% endif -%}
|
|
||||||
{% set builtin_tools = builtin_tools | reject('equalto', 'code_interpreter') | list -%}
|
|
||||||
{% if builtin_tools -%}
|
|
||||||
Tools: {{ builtin_tools | join(", ") | trim -}}
|
|
||||||
{% endif %}
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
return PromptTemplate(
|
|
||||||
template_str.lstrip("\n"),
|
|
||||||
{
|
|
||||||
"builtin_tools": [t.tool_name.value for t in builtin_tools],
|
|
||||||
"custom_tools": custom_tools,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def data_examples(self) -> List[List[ToolDefinition]]:
|
|
||||||
return [
|
|
||||||
# builtin tools
|
|
||||||
[
|
|
||||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
|
||||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
|
||||||
ToolDefinition(tool_name=BuiltinTool.wolfram_alpha),
|
|
||||||
],
|
|
||||||
# only code interpretor
|
|
||||||
[
|
|
||||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
|
||||||
],
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
|
||||||
|
|
||||||
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
|
||||||
template_str = textwrap.dedent(
|
|
||||||
"""
|
|
||||||
Answer the user's question by making use of the following functions if needed.
|
|
||||||
If none of the function can be used, please say so.
|
|
||||||
Here is a list of functions in JSON format:
|
|
||||||
{% for t in custom_tools -%}
|
|
||||||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
|
||||||
{%- set tname = t.tool_name -%}
|
|
||||||
{%- set tdesc = t.description -%}
|
|
||||||
{%- set tparams = t.parameters -%}
|
|
||||||
{%- set required_params = [] -%}
|
|
||||||
{%- for name, param in tparams.items() if param.required == true -%}
|
|
||||||
{%- set _ = required_params.append(name) -%}
|
|
||||||
{%- endfor -%}
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "{{tname}}",
|
|
||||||
"description": "{{tdesc}}",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": [
|
|
||||||
{%- for name, param in tparams.items() %}
|
|
||||||
{
|
|
||||||
"{{name}}": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "{{param.description}}"
|
|
||||||
}
|
|
||||||
}{% if not loop.last %},{% endif %}
|
|
||||||
{%- endfor %}
|
|
||||||
],
|
|
||||||
"required": {{ required_params | tojson }}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
{% endfor %}
|
|
||||||
Return function calls in JSON format.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
return PromptTemplate(
|
|
||||||
template_str.lstrip("\n"),
|
|
||||||
{"custom_tools": [t.model_dump() for t in custom_tools]},
|
|
||||||
)
|
|
||||||
|
|
||||||
def data_examples(self) -> List[List[ToolDefinition]]:
|
|
||||||
return [
|
|
||||||
[
|
|
||||||
ToolDefinition(
|
|
||||||
tool_name="trending_songs",
|
|
||||||
description="Returns the trending songs on a Music site",
|
|
||||||
parameters={
|
|
||||||
"n": ToolParamDefinition(
|
|
||||||
param_type="int",
|
|
||||||
description="The number of songs to return",
|
|
||||||
required=True,
|
|
||||||
),
|
|
||||||
"genre": ToolParamDefinition(
|
|
||||||
param_type="str",
|
|
||||||
description="The genre of the songs to return",
|
|
||||||
required=False,
|
|
||||||
),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
|
|
||||||
|
|
||||||
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
|
||||||
template_str = textwrap.dedent(
|
|
||||||
"""
|
|
||||||
You have access to the following functions:
|
|
||||||
|
|
||||||
{% for t in custom_tools %}
|
|
||||||
{#- manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
|
||||||
{%- set tname = t.tool_name -%}
|
|
||||||
{%- set tdesc = t.description -%}
|
|
||||||
{%- set tparams = t.parameters | tojson -%}
|
|
||||||
Use the function '{{ tname }}' to '{{ tdesc }}':
|
|
||||||
{"name": "{{tname}}", "description": "{{tdesc}}", "parameters": {{tparams}}}
|
|
||||||
|
|
||||||
{% endfor -%}
|
|
||||||
Think very carefully before calling functions.
|
|
||||||
If a you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
|
||||||
|
|
||||||
<function=example_function_name>{"example_name": "example_value"}</function>
|
|
||||||
|
|
||||||
Reminder:
|
|
||||||
- If looking for real time information use relevant functions before falling back to brave_search
|
|
||||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
|
||||||
- Required parameters MUST be specified
|
|
||||||
- Only call one function at a time
|
|
||||||
- Put the entire function call reply on one line
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
return PromptTemplate(
|
|
||||||
template_str.lstrip("\n"),
|
|
||||||
{"custom_tools": [t.model_dump() for t in custom_tools]},
|
|
||||||
)
|
|
||||||
|
|
||||||
def data_examples(self) -> List[List[ToolDefinition]]:
|
|
||||||
return [
|
|
||||||
[
|
|
||||||
ToolDefinition(
|
|
||||||
tool_name="trending_songs",
|
|
||||||
description="Returns the trending songs on a Music site",
|
|
||||||
parameters={
|
|
||||||
"n": ToolParamDefinition(
|
|
||||||
param_type="int",
|
|
||||||
description="The number of songs to return",
|
|
||||||
required=True,
|
|
||||||
),
|
|
||||||
"genre": ToolParamDefinition(
|
|
||||||
param_type="str",
|
|
||||||
description="The genre of the songs to return",
|
|
||||||
required=False,
|
|
||||||
),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
]
|
|
|
@ -15,41 +15,6 @@ from typing_extensions import Annotated
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ToolChoice(Enum):
|
|
||||||
auto = "auto"
|
|
||||||
required = "required"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ToolPromptFormat(Enum):
|
|
||||||
"""This Enum refers to the prompt format for calling zero shot tools
|
|
||||||
|
|
||||||
`json` --
|
|
||||||
Refers to the json format for calling tools.
|
|
||||||
The json format takes the form like
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function" : {
|
|
||||||
"name": "function_name",
|
|
||||||
"description": "function_description",
|
|
||||||
"parameters": {...}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
`function_tag` --
|
|
||||||
This is an example of how you could define
|
|
||||||
your own user defined format for making tool calls.
|
|
||||||
The function_tag format looks like this,
|
|
||||||
<function=function_name>(parameters)</function>
|
|
||||||
|
|
||||||
The detailed prompts for each of these formats are defined in `system_prompt.py`
|
|
||||||
"""
|
|
||||||
|
|
||||||
json = "json"
|
|
||||||
function_tag = "function_tag"
|
|
||||||
|
|
||||||
|
|
||||||
class LogProbConfig(BaseModel):
|
class LogProbConfig(BaseModel):
|
||||||
top_k: Optional[int] = 0
|
top_k: Optional[int] = 0
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
from .datatypes import * # noqa: F403
|
from .datatypes import * # noqa: F403
|
||||||
from typing import Optional, Protocol
|
from typing import Optional, Protocol
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import ToolDefinition
|
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
|
||||||
|
|
||||||
# this dependency is annoying and we need a forked up version anyway
|
# this dependency is annoying and we need a forked up version anyway
|
||||||
from llama_models.schema_utils import webmethod
|
from llama_models.schema_utils import webmethod
|
||||||
|
@ -16,7 +16,7 @@ from llama_models.schema_utils import webmethod
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class CompletionRequest(BaseModel):
|
class CompletionRequest(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
content: InterleavedTextAttachment
|
content: InterleavedTextMedia
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||||
|
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
|
@ -41,7 +41,7 @@ class CompletionResponseStreamChunk(BaseModel):
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class BatchCompletionRequest(BaseModel):
|
class BatchCompletionRequest(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
content_batch: List[InterleavedTextAttachment]
|
content_batch: List[InterleavedTextMedia]
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||||
logprobs: Optional[LogProbConfig] = None
|
logprobs: Optional[LogProbConfig] = None
|
||||||
|
|
||||||
|
|
|
@ -11,10 +11,10 @@ from llama_models.datatypes import ModelFamily
|
||||||
from llama_models.schema_utils import json_schema_type
|
from llama_models.schema_utils import json_schema_type
|
||||||
from llama_models.sku_list import all_registered_models
|
from llama_models.sku_list import all_registered_models
|
||||||
|
|
||||||
from llama_toolchain.inference.api import QuantizationConfig
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from llama_toolchain.inference.api import QuantizationConfig
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class MetaReferenceImplConfig(BaseModel):
|
class MetaReferenceImplConfig(BaseModel):
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_toolchain.inference.api import * # noqa: F403
|
from llama_toolchain.inference.api import * # noqa: F403
|
||||||
from llama_toolchain.common.prompt_templates.system_prompts import (
|
from llama_models.llama3.prompt_templates import (
|
||||||
BuiltinToolGenerator,
|
BuiltinToolGenerator,
|
||||||
FunctionTagCustomToolGenerator,
|
FunctionTagCustomToolGenerator,
|
||||||
JsonCustomToolGenerator,
|
JsonCustomToolGenerator,
|
||||||
|
|
|
@ -15,7 +15,7 @@ from llama_toolchain.agentic_system.event_logger import EventLogger, LogEvent
|
||||||
from llama_toolchain.agentic_system.utils import get_agent_system_instance
|
from llama_toolchain.agentic_system.utils import get_agent_system_instance
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_toolchain.agentic_system.api.datatypes import StepType, ToolPromptFormat
|
from llama_toolchain.agentic_system.api.datatypes import StepType
|
||||||
from llama_toolchain.tools.custom.datatypes import CustomTool
|
from llama_toolchain.tools.custom.datatypes import CustomTool
|
||||||
|
|
||||||
from tests.example_custom_tool import GetBoilingPointTool
|
from tests.example_custom_tool import GetBoilingPointTool
|
||||||
|
|
|
@ -14,13 +14,11 @@ from llama_models.llama3.api.datatypes import (
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolParamDefinition,
|
ToolParamDefinition,
|
||||||
|
ToolPromptFormat,
|
||||||
ToolResponseMessage,
|
ToolResponseMessage,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_toolchain.inference.api.datatypes import (
|
from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
|
||||||
ChatCompletionResponseEventType,
|
|
||||||
ToolPromptFormat,
|
|
||||||
)
|
|
||||||
|
|
||||||
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
|
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
|
||||||
from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig
|
from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig
|
||||||
|
|
|
@ -10,13 +10,11 @@ from llama_models.llama3.api.datatypes import (
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolParamDefinition,
|
ToolParamDefinition,
|
||||||
|
ToolPromptFormat,
|
||||||
ToolResponseMessage,
|
ToolResponseMessage,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_toolchain.inference.api.datatypes import (
|
from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
|
||||||
ChatCompletionResponseEventType,
|
|
||||||
ToolPromptFormat,
|
|
||||||
)
|
|
||||||
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
|
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
|
||||||
from llama_toolchain.inference.ollama.config import OllamaImplConfig
|
from llama_toolchain.inference.ollama.config import OllamaImplConfig
|
||||||
from llama_toolchain.inference.ollama.ollama import get_provider_impl
|
from llama_toolchain.inference.ollama.ollama import get_provider_impl
|
||||||
|
|
|
@ -1,101 +0,0 @@
|
||||||
import textwrap
|
|
||||||
import unittest
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from llama_toolchain.common.prompt_templates.system_prompts import (
|
|
||||||
BuiltinToolGenerator,
|
|
||||||
FunctionTagCustomToolGenerator,
|
|
||||||
JsonCustomToolGenerator,
|
|
||||||
SystemDefaultGenerator,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplateTests(unittest.TestCase):
|
|
||||||
|
|
||||||
def check_generator_output(self, generator, expected_text):
|
|
||||||
example = generator.data_examples()[0]
|
|
||||||
|
|
||||||
pt = generator.gen(example)
|
|
||||||
text = pt.render()
|
|
||||||
# print(text) # debugging
|
|
||||||
self.assertEqual(text, expected_text)
|
|
||||||
|
|
||||||
def test_system_default(self):
|
|
||||||
generator = SystemDefaultGenerator()
|
|
||||||
today = datetime.now().strftime("%d %B %Y")
|
|
||||||
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
|
||||||
self.check_generator_output(generator, expected_text)
|
|
||||||
|
|
||||||
def test_system_builtin_only(self):
|
|
||||||
generator = BuiltinToolGenerator()
|
|
||||||
expected_text = textwrap.dedent(
|
|
||||||
"""
|
|
||||||
Environment: ipython
|
|
||||||
Tools: brave_search, wolfram_alpha
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
|
||||||
|
|
||||||
def test_system_custom_only(self):
|
|
||||||
self.maxDiff = None
|
|
||||||
generator = JsonCustomToolGenerator()
|
|
||||||
expected_text = textwrap.dedent(
|
|
||||||
"""
|
|
||||||
Answer the user's question by making use of the following functions if needed.
|
|
||||||
If none of the function can be used, please say so.
|
|
||||||
Here is a list of functions in JSON format:
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "trending_songs",
|
|
||||||
"description": "Returns the trending songs on a Music site",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": [
|
|
||||||
{
|
|
||||||
"n": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "The number of songs to return"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"genre": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "The genre of the songs to return"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"required": ["n"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Return function calls in JSON format.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
|
||||||
|
|
||||||
def test_system_custom_function_tag(self):
|
|
||||||
self.maxDiff = None
|
|
||||||
generator = FunctionTagCustomToolGenerator()
|
|
||||||
expected_text = textwrap.dedent(
|
|
||||||
"""
|
|
||||||
You have access to the following functions:
|
|
||||||
|
|
||||||
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
|
|
||||||
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
|
|
||||||
|
|
||||||
Think very carefully before calling functions.
|
|
||||||
If a you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
|
||||||
|
|
||||||
<function=example_function_name>{"example_name": "example_value"}</function>
|
|
||||||
|
|
||||||
Reminder:
|
|
||||||
- If looking for real time information use relevant functions before falling back to brave_search
|
|
||||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
|
||||||
- Required parameters MUST be specified
|
|
||||||
- Only call one function at a time
|
|
||||||
- Put the entire function call reply on one line
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
self.check_generator_output(generator, expected_text.strip("\n"))
|
|
Loading…
Add table
Add a link
Reference in a new issue