mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-01 20:00:01 +00:00
feat: RFC: tools API rework
# What does this PR do?
This PR proposes updates to the tools API in Inference and Agent.
Goals:
1. Agent's tool specification should be consistent with Inference's tool spec, but with add-ons.
2. Formal types should be defined for built in tools. Currently Agent tools args are untyped, e.g. how does one know that `builtin::rag_tool` takes a `vector_db_ids` param or even how to know 'builtin::rag_tool' is even available (in code, outside of docs)?
Inference:
1. BuiltinTool is to be removed and replaced by a formal `type` parameter.
2. 'brave_search' is replaced by 'web_search' to be more generic. It will still be translated back to brave_search when the prompt is constructed to be consistent with model training.
3. I'm not sure what `photogen` is. Maybe it can be removed?
Agent:
1. Uses the same format as in Inference for builtin tools.
2. New tools types are added, i.e. knowledge_sesarch (currently rag_tool), and MCP tool.
3. Toolgroup as a concept will be removed since it's really only used for MCP.
4. Instead MCPTool is its own type and available tools provided by the server will be expanded by default. Users can specify a subset of tool names if desired.
Example snippet:
```
agent = Agent(
client,
model=model_id,
instructions="You are a helpful assistant. Use the tools you have access to for providing relevant answers.",
tools=[
KnowledgeSearchTool(vector_store_id="1234"),
KnowledgeSearchTool(vector_store_id="5678", name="paper_search", description="Search research papers"),
KnowledgeSearchTool(vector_store_id="1357", name="wiki_search", description="Search wiki pages"),
# no need to register toolgroup, just pass in the server uri
# all available tools will be used
MCPTool(server_uri="http://localhost:8000/sse"),
# can specify a subset of available tools
MCPTool(server_uri="http://localhost:8000/sse", tool_names=["list_directory"]),
MCPTool(server_uri="http://localhost:8000/sse", tool_names=["list_directory"]),
# custom tool
my_custom_tool,
]
)
```
## Test Plan
# What does this PR do?
## Test Plan
# What does this PR do?
## Test Plan
This commit is contained in:
parent
39e094736f
commit
7027b537e0
22 changed files with 951 additions and 525 deletions
|
|
@ -33,10 +33,10 @@ class Role(Enum):
|
|||
tool = "tool"
|
||||
|
||||
|
||||
class BuiltinTool(Enum):
|
||||
brave_search = "brave_search"
|
||||
class ToolType(Enum):
|
||||
function = "function"
|
||||
web_search = "web_search"
|
||||
wolfram_alpha = "wolfram_alpha"
|
||||
photogen = "photogen"
|
||||
code_interpreter = "code_interpreter"
|
||||
|
||||
|
||||
|
|
@ -45,8 +45,9 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
|
|||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
type: ToolType
|
||||
call_id: str
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
tool_name: str
|
||||
# Plan is to deprecate the Dict in favor of a JSON string
|
||||
# that is parsed on the client side instead of trying to manage
|
||||
# the recursive type here.
|
||||
|
|
@ -59,12 +60,18 @@ class ToolCall(BaseModel):
|
|||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
# for backwards compatibility, we allow the tool name to be a string or a BuiltinTool
|
||||
# TODO: remove ToolDefinitionDeprecated in v0.1.10
|
||||
tool_name = v
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinTool(v)
|
||||
tool_name = BuiltinTool(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
pass
|
||||
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
return tool_name.to_tool().type
|
||||
return tool_name
|
||||
|
||||
|
||||
class ToolPromptFormat(Enum):
|
||||
|
|
@ -151,8 +158,136 @@ class ToolParamDefinition(BaseModel):
|
|||
default: Optional[Any] = None
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
type: ToolType
|
||||
|
||||
@classmethod
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
required_properties = ["name", "description", "parameters"]
|
||||
for prop in required_properties:
|
||||
has_property = any(isinstance(v, property) for v in [cls.__dict__.get(prop)])
|
||||
has_field = prop in cls.__annotations__ or prop in cls.__dict__
|
||||
if not has_property and not has_field:
|
||||
raise TypeError(f"Class {cls.__name__} must implement '{prop}' property or field")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolDefinition(BaseModel):
|
||||
class WebSearchTool(Tool):
|
||||
type: Literal[ToolType.web_search.value] = ToolType.web_search.value
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "web_search"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Search the web for information"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, ToolParamDefinition]:
|
||||
return {
|
||||
"query": ToolParamDefinition(
|
||||
description="The query to search for",
|
||||
param_type="string",
|
||||
required=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class WolframAlphaTool(Tool):
|
||||
type: Literal[ToolType.wolfram_alpha.value] = ToolType.wolfram_alpha.value
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "wolfram_alpha"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Query WolframAlpha for computational knowledge"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, ToolParamDefinition]:
|
||||
return {
|
||||
"query": ToolParamDefinition(
|
||||
description="The query to compute",
|
||||
param_type="string",
|
||||
required=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CodeInterpreterTool(Tool):
|
||||
type: Literal[ToolType.code_interpreter.value] = ToolType.code_interpreter.value
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "code_interpreter"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Execute code"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, ToolParamDefinition]:
|
||||
return {
|
||||
"code": ToolParamDefinition(
|
||||
description="The code to execute",
|
||||
param_type="string",
|
||||
required=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FunctionTool(Tool):
|
||||
type: Literal[ToolType.function.value] = ToolType.function.value
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[Dict[str, ToolParamDefinition]] = None
|
||||
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def validate_name(cls, v):
|
||||
if v in ToolType.__members__:
|
||||
raise ValueError(f"Tool name '{v}' is a tool type and cannot be used as a name of a function tool")
|
||||
return v
|
||||
|
||||
|
||||
ToolDefinition = Annotated[
|
||||
Union[WebSearchTool, WolframAlphaTool, CodeInterpreterTool, FunctionTool], Field(discriminator="type")
|
||||
]
|
||||
|
||||
|
||||
# TODO: remove ToolDefinitionDeprecated in v0.1.10
|
||||
class BuiltinTool(Enum):
|
||||
brave_search = "brave_search"
|
||||
wolfram_alpha = "wolfram_alpha"
|
||||
code_interpreter = "code_interpreter"
|
||||
|
||||
def to_tool_type(self) -> ToolType:
|
||||
if self == BuiltinTool.brave_search:
|
||||
return ToolType.web_search
|
||||
elif self == BuiltinTool.wolfram_alpha:
|
||||
return ToolType.wolfram_alpha
|
||||
elif self == BuiltinTool.code_interpreter:
|
||||
return ToolType.code_interpreter
|
||||
|
||||
def to_tool(self) -> WebSearchTool | WolframAlphaTool | CodeInterpreterTool:
|
||||
if self == BuiltinTool.brave_search:
|
||||
return WebSearchTool()
|
||||
elif self == BuiltinTool.wolfram_alpha:
|
||||
return WolframAlphaTool()
|
||||
elif self == BuiltinTool.code_interpreter:
|
||||
return CodeInterpreterTool()
|
||||
|
||||
|
||||
# TODO: remove ToolDefinitionDeprecated in v0.1.10
|
||||
@json_schema_type
|
||||
class ToolDefinitionDeprecated(BaseModel):
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[Dict[str, ToolParamDefinition]] = None
|
||||
|
|
@ -167,6 +302,21 @@ class ToolDefinition(BaseModel):
|
|||
return v
|
||||
return v
|
||||
|
||||
def to_tool_definition(self) -> ToolDefinition:
|
||||
# convert to ToolDefinition
|
||||
if self.tool_name == BuiltinTool.brave_search:
|
||||
return WebSearchTool()
|
||||
elif self.tool_name == BuiltinTool.code_interpreter:
|
||||
return CodeInterpreterTool()
|
||||
elif self.tool_name == BuiltinTool.wolfram_alpha:
|
||||
return WolframAlphaTool()
|
||||
else:
|
||||
return FunctionTool(
|
||||
name=self.tool_name,
|
||||
description=self.description,
|
||||
parameters=self.parameters,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GreedySamplingStrategy(BaseModel):
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Tuple
|
|||
from PIL import Image as PIL_Image
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
RawMessage,
|
||||
|
|
@ -29,6 +28,7 @@ from llama_stack.models.llama.datatypes import (
|
|||
StopReason,
|
||||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
ToolType,
|
||||
)
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
|
|
@ -127,7 +127,7 @@ class ChatFormat:
|
|||
if (
|
||||
message.role == "assistant"
|
||||
and len(message.tool_calls) > 0
|
||||
and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter
|
||||
and message.tool_calls[0].type == ToolType.code_interpreter
|
||||
):
|
||||
tokens.append(self.tokenizer.special_tokens["<|python_tag|>"])
|
||||
|
||||
|
|
@ -194,6 +194,7 @@ class ChatFormat:
|
|||
stop_reason = StopReason.end_of_message
|
||||
|
||||
tool_name = None
|
||||
tool_type = ToolType.function
|
||||
tool_arguments = {}
|
||||
|
||||
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
|
||||
|
|
@ -202,8 +203,8 @@ class ChatFormat:
|
|||
# Sometimes when agent has custom tools alongside builin tools
|
||||
# Agent responds for builtin tool calls in the format of the custom tools
|
||||
# This code tries to handle that case
|
||||
if tool_name in BuiltinTool.__members__:
|
||||
tool_name = BuiltinTool[tool_name]
|
||||
if tool_name in ToolType.__members__:
|
||||
tool_type = ToolType[tool_name]
|
||||
if isinstance(tool_arguments, dict):
|
||||
tool_arguments = {
|
||||
"query": list(tool_arguments.values())[0],
|
||||
|
|
@ -215,10 +216,11 @@ class ChatFormat:
|
|||
tool_arguments = {
|
||||
"query": query,
|
||||
}
|
||||
if tool_name in BuiltinTool.__members__:
|
||||
tool_name = BuiltinTool[tool_name]
|
||||
if tool_name in ToolType.__members__:
|
||||
tool_type = ToolType[tool_name]
|
||||
elif ipython:
|
||||
tool_name = BuiltinTool.code_interpreter
|
||||
tool_name = ToolType.code_interpreter.value
|
||||
tool_type = ToolType.code_interpreter
|
||||
tool_arguments = {
|
||||
"code": content,
|
||||
}
|
||||
|
|
@ -228,6 +230,7 @@ class ChatFormat:
|
|||
call_id = str(uuid.uuid4())
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type=tool_type,
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from typing import List, Optional
|
|||
from termcolor import colored
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
FunctionTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
|
|
@ -25,7 +25,6 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from . import template_data
|
||||
from .chat_format import ChatFormat
|
||||
from .prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
|
|
@ -150,8 +149,8 @@ class LLama31Interface:
|
|||
|
||||
def system_messages(
|
||||
self,
|
||||
builtin_tools: List[BuiltinTool],
|
||||
custom_tools: List[ToolDefinition],
|
||||
builtin_tools: List[ToolDefinition],
|
||||
custom_tools: List[FunctionTool],
|
||||
instruction: Optional[str] = None,
|
||||
) -> List[RawMessage]:
|
||||
messages = []
|
||||
|
|
@ -227,31 +226,3 @@ class LLama31Interface:
|
|||
on_col = on_colors[i % len(on_colors)]
|
||||
print(colored(self.tokenizer.decode([t]), "white", on_col), end="")
|
||||
print("\n", end="")
|
||||
|
||||
|
||||
def list_jinja_templates() -> List[Template]:
|
||||
return TEMPLATES
|
||||
|
||||
|
||||
def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat):
|
||||
by_name = {t.template_name: t for t in TEMPLATES}
|
||||
if name not in by_name:
|
||||
raise ValueError(f"No template found for `{name}`")
|
||||
|
||||
template = by_name[name]
|
||||
interface = LLama31Interface(tool_prompt_format)
|
||||
|
||||
data_func = getattr(template_data, template.data_provider)
|
||||
if template.role == "system":
|
||||
messages = interface.system_messages(**data_func())
|
||||
elif template.role == "tool":
|
||||
messages = interface.tool_response_messages(**data_func())
|
||||
elif template.role == "assistant":
|
||||
messages = interface.assistant_response_messages(**data_func())
|
||||
elif template.role == "user":
|
||||
messages = interface.user_message(**data_func())
|
||||
|
||||
tokens = interface.get_tokens(messages)
|
||||
special_tokens = list(interface.tokenizer.special_tokens.values())
|
||||
tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
|
||||
return template, tokens
|
||||
|
|
|
|||
|
|
@ -16,9 +16,13 @@ from datetime import datetime
|
|||
from typing import Any, List, Optional
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
CodeInterpreterTool,
|
||||
FunctionTool,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolType,
|
||||
WebSearchTool,
|
||||
WolframAlphaTool,
|
||||
)
|
||||
|
||||
from .base import PromptTemplate, PromptTemplateGeneratorBase
|
||||
|
|
@ -47,7 +51,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
|
|||
def _tool_breakdown(self, tools: List[ToolDefinition]):
|
||||
builtin_tools, custom_tools = [], []
|
||||
for dfn in tools:
|
||||
if isinstance(dfn.tool_name, BuiltinTool):
|
||||
if dfn.type != ToolType.function.value:
|
||||
builtin_tools.append(dfn)
|
||||
else:
|
||||
custom_tools.append(dfn)
|
||||
|
|
@ -70,7 +74,11 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
|
|||
return PromptTemplate(
|
||||
template_str.lstrip("\n"),
|
||||
{
|
||||
"builtin_tools": [t.tool_name.value for t in builtin_tools],
|
||||
"builtin_tools": [
|
||||
# brave_search is used in training data for web_search
|
||||
t.type if t.type != ToolType.web_search.value else "brave_search"
|
||||
for t in builtin_tools
|
||||
],
|
||||
"custom_tools": custom_tools,
|
||||
},
|
||||
)
|
||||
|
|
@ -79,19 +87,19 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
|
|||
return [
|
||||
# builtin tools
|
||||
[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
ToolDefinition(tool_name=BuiltinTool.wolfram_alpha),
|
||||
CodeInterpreterTool(),
|
||||
WebSearchTool(),
|
||||
WolframAlphaTool(),
|
||||
],
|
||||
# only code interpretor
|
||||
[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
CodeInterpreterTool(),
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
||||
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||
def gen(self, custom_tools: List[FunctionTool]) -> PromptTemplate:
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
Answer the user's question by making use of the following functions if needed.
|
||||
|
|
@ -99,7 +107,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
Here is a list of functions in JSON format:
|
||||
{% for t in custom_tools -%}
|
||||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tname = t.name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set tparams = t.parameters -%}
|
||||
{%- set required_params = [] -%}
|
||||
|
|
@ -140,8 +148,8 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||
return [
|
||||
[
|
||||
ToolDefinition(
|
||||
tool_name="trending_songs",
|
||||
FunctionTool(
|
||||
name="trending_songs",
|
||||
description="Returns the trending songs on a Music site",
|
||||
parameters={
|
||||
"n": ToolParamDefinition(
|
||||
|
|
@ -161,14 +169,14 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
|
||||
|
||||
class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
|
||||
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||
def gen(self, custom_tools: List[FunctionTool]) -> PromptTemplate:
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
You have access to the following functions:
|
||||
|
||||
{% for t in custom_tools %}
|
||||
{#- manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tname = t.name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set modified_params = t.parameters.copy() -%}
|
||||
{%- for key, value in modified_params.items() -%}
|
||||
|
|
@ -202,8 +210,8 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
|
|||
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||
return [
|
||||
[
|
||||
ToolDefinition(
|
||||
tool_name="trending_songs",
|
||||
FunctionTool(
|
||||
name="trending_songs",
|
||||
description="Returns the trending songs on a Music site",
|
||||
parameters={
|
||||
"n": ToolParamDefinition(
|
||||
|
|
@ -240,7 +248,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
{"function_description": self._gen_function_description(custom_tools)},
|
||||
)
|
||||
|
||||
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||
def _gen_function_description(self, custom_tools: List[FunctionTool]) -> PromptTemplate:
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
|
|
@ -252,7 +260,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
[
|
||||
{% for t in tools -%}
|
||||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
{%- set tname = t.tool_name -%}
|
||||
{%- set tname = t.name -%}
|
||||
{%- set tdesc = t.description -%}
|
||||
{%- set tparams = t.parameters -%}
|
||||
{%- set required_params = [] -%}
|
||||
|
|
@ -289,8 +297,8 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||
return [
|
||||
[
|
||||
ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
FunctionTool(
|
||||
name="get_weather",
|
||||
description="Get weather info for places",
|
||||
parameters={
|
||||
"city": ToolParamDefinition(
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import re
|
|||
from typing import Optional, Tuple
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
from llama_stack.models.llama.datatypes import RecursiveType, ToolCall, ToolPromptFormat, ToolType
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
|
@ -24,6 +24,12 @@ BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
|||
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
||||
|
||||
|
||||
# The model is trained with brave_search for web_search, so we need to map it
|
||||
TOOL_NAME_MAP = {
|
||||
"brave_search": ToolType.web_search.value,
|
||||
}
|
||||
|
||||
|
||||
def is_json(s):
|
||||
try:
|
||||
parsed = json.loads(s)
|
||||
|
|
@ -111,11 +117,6 @@ def parse_python_list_for_function_calls(input_string):
|
|||
|
||||
|
||||
class ToolUtils:
|
||||
@staticmethod
|
||||
def is_builtin_tool_call(message_body: str) -> bool:
|
||||
match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body)
|
||||
return match is not None
|
||||
|
||||
@staticmethod
|
||||
def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
|
||||
# Find the first match in the text
|
||||
|
|
@ -125,7 +126,7 @@ class ToolUtils:
|
|||
if match:
|
||||
tool_name = match.group("tool_name")
|
||||
query = match.group("query")
|
||||
return tool_name, query
|
||||
return TOOL_NAME_MAP.get(tool_name, tool_name), query
|
||||
else:
|
||||
return None
|
||||
|
||||
|
|
@ -143,7 +144,7 @@ class ToolUtils:
|
|||
tool_name = match.group("function_name")
|
||||
query = match.group("args")
|
||||
try:
|
||||
return tool_name, json.loads(query.replace("'", '"'))
|
||||
return TOOL_NAME_MAP.get(tool_name, tool_name), json.loads(query.replace("'", '"'))
|
||||
except Exception as e:
|
||||
print("Exception while parsing json query for custom tool call", query, e)
|
||||
return None
|
||||
|
|
@ -152,30 +153,28 @@ class ToolUtils:
|
|||
if ("type" in response and response["type"] == "function") or ("name" in response):
|
||||
function_name = response["name"]
|
||||
args = response["parameters"]
|
||||
return function_name, args
|
||||
return TOOL_NAME_MAP.get(function_name, function_name), args
|
||||
else:
|
||||
return None
|
||||
elif is_valid_python_list(message_body):
|
||||
res = parse_python_list_for_function_calls(message_body)
|
||||
# FIXME: Enable multiple tool calls
|
||||
return res[0]
|
||||
function_name, args = res[0]
|
||||
return TOOL_NAME_MAP.get(function_name, function_name), args
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
|
||||
if t.tool_name == BuiltinTool.brave_search:
|
||||
if t.type == ToolType.web_search:
|
||||
q = t.arguments["query"]
|
||||
return f'brave_search.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.wolfram_alpha:
|
||||
elif t.type == ToolType.wolfram_alpha:
|
||||
q = t.arguments["query"]
|
||||
return f'wolfram_alpha.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.photogen:
|
||||
q = t.arguments["query"]
|
||||
return f'photogen.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.code_interpreter:
|
||||
elif t.type == ToolType.code_interpreter:
|
||||
return t.arguments["code"]
|
||||
else:
|
||||
elif t.type == ToolType.function:
|
||||
fname = t.tool_name
|
||||
|
||||
if tool_prompt_format == ToolPromptFormat.json:
|
||||
|
|
@ -208,3 +207,5 @@ class ToolUtils:
|
|||
return f"[{fname}({args_str})]"
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool type: {t.type}")
|
||||
|
|
|
|||
|
|
@ -15,11 +15,11 @@ import textwrap
|
|||
from typing import List
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
ToolType,
|
||||
)
|
||||
|
||||
from ..prompt_format import (
|
||||
|
|
@ -184,8 +184,9 @@ def usecases() -> List[UseCase | str]:
|
|||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
type=ToolType.wolfram_alpha,
|
||||
call_id="tool_call_id",
|
||||
tool_name=BuiltinTool.wolfram_alpha,
|
||||
tool_name=ToolType.wolfram_alpha.value,
|
||||
arguments={"query": "100th decimal of pi"},
|
||||
)
|
||||
],
|
||||
|
|
|
|||
|
|
@ -15,11 +15,11 @@ import textwrap
|
|||
from typing import List
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
ToolType,
|
||||
)
|
||||
|
||||
from ..prompt_format import (
|
||||
|
|
@ -183,8 +183,9 @@ def usecases() -> List[UseCase | str]:
|
|||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
type=ToolType.wolfram_alpha,
|
||||
call_id="tool_call_id",
|
||||
tool_name=BuiltinTool.wolfram_alpha,
|
||||
tool_name=ToolType.wolfram_alpha.value,
|
||||
arguments={"query": "100th decimal of pi"},
|
||||
)
|
||||
],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue