feat: RFC: tools API rework

# What does this PR do?
This PR proposes updates to the tools API in Inference and Agent.

Goals:
1. Agent's tool specification should be consistent with Inference's tool spec, with agent-specific add-ons.
2. Formal types should be defined for built-in tools. Agent tool arguments are currently untyped: in code (rather than docs), how would one know that `builtin::rag_tool` takes a `vector_db_ids` param, or that `builtin::rag_tool` is available at all?

Inference (a request-shape sketch follows this list):
1. `BuiltinTool` is removed and replaced by a formal `type` parameter.
2. `brave_search` is replaced by the more generic `web_search`. It is still translated back to `brave_search` when the prompt is constructed, to stay consistent with model training.
3. `photogen` is removed; its purpose is unclear (flagging this as an open question in case anything still depends on it).
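
For illustration, a chat completion request under the new spec might look like the sketch below. The tool classes come from this PR's datatypes; the surrounding client call is schematic, not a final API:

```python
# Sketch only: tool classes are from this PR; the client call shape is illustrative.
from llama_stack.models.llama.datatypes import (
    FunctionTool,
    ToolParamDefinition,
    WebSearchTool,
)

tools = [
    WebSearchTool(),  # was ToolDefinition(tool_name=BuiltinTool.brave_search)
    FunctionTool(
        name="get_weather",
        description="Get weather info for places",
        parameters={
            "city": ToolParamDefinition(
                param_type="string",
                description="The name of the city",
                required=True,
            ),
        },
    ),
]

response = client.inference.chat_completion(
    model_id=model_id,
    messages=[{"role": "user", "content": "What's the weather in San Francisco?"}],
    tools=tools,
)
```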

Agent:
1. Uses the same format as Inference for builtin tools.
2. New tool types are added, i.e. `knowledge_search` (currently `rag_tool`) and MCP tools.
3. Toolgroup as a concept is removed, since it is really only used for MCP.
4. Instead, `MCPTool` is its own type; by default, all tools provided by the server are expanded. Users can specify a subset of tool names if desired.

Example snippet:
```python
agent = Agent(
    client,
    model=model_id,
    instructions="You are a helpful assistant. Use the tools you have access to for providing relevant answers.",
    tools=[
        KnowledgeSearchTool(vector_store_id="1234"),
        KnowledgeSearchTool(vector_store_id="5678", name="paper_search", description="Search research papers"),
        KnowledgeSearchTool(vector_store_id="1357", name="wiki_search", description="Search wiki pages"),
        # no need to register toolgroup, just pass in the server uri
        # all available tools will be used
        MCPTool(server_uri="http://localhost:8000/sse"),
        # can specify a subset of available tools
        MCPTool(server_uri="http://localhost:8000/sse", tool_names=["list_directory"]),
        # custom tool
        my_custom_tool,
    ]
)
```
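
For backward compatibility, the old shape is kept as `ToolDefinitionDeprecated` (slated for removal in v0.1.10): the inference router logs a warning and converts it on the fly, so existing callers keep working. A minimal sketch of that path:

```python
# Deprecated path, still accepted (with a warning) until v0.1.10.
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinitionDeprecated

old_style = ToolDefinitionDeprecated(tool_name=BuiltinTool.brave_search)
# The router performs this conversion internally:
new_style = old_style.to_tool_definition()  # -> WebSearchTool()
```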

## Test Plan

Updated integration tests for agents and inference tool calling are included in the diff below.
Commit `7027b537e0` (parent `39e094736f`) by Eric Huang, 2025-03-26: 22 changed files with 951 additions and 525 deletions.

**OpenAPI spec (JSON)**

```diff
@@ -3798,6 +3798,21 @@
       ],
       "title": "AppendRowsRequest"
     },
+    "CodeInterpreterTool": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "code_interpreter",
+          "default": "code_interpreter"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
+      ],
+      "title": "CodeInterpreterTool"
+    },
     "CompletionMessage": {
       "type": "object",
       "properties": {
@@ -3837,6 +3852,34 @@
       "title": "CompletionMessage",
       "description": "A message containing the model's (assistant) response in a chat conversation."
     },
+    "FunctionTool": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "function",
+          "default": "function"
+        },
+        "name": {
+          "type": "string"
+        },
+        "description": {
+          "type": "string"
+        },
+        "parameters": {
+          "type": "object",
+          "additionalProperties": {
+            "$ref": "#/components/schemas/ToolParamDefinition"
+          }
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type",
+        "name"
+      ],
+      "title": "FunctionTool"
+    },
     "GrammarResponseFormat": {
       "type": "object",
       "properties": {
@@ -4138,25 +4181,21 @@
     "ToolCall": {
       "type": "object",
       "properties": {
+        "type": {
+          "type": "string",
+          "enum": [
+            "function",
+            "web_search",
+            "wolfram_alpha",
+            "code_interpreter"
+          ],
+          "title": "ToolType"
+        },
         "call_id": {
           "type": "string"
         },
         "tool_name": {
-          "oneOf": [
-            {
-              "type": "string",
-              "enum": [
-                "brave_search",
-                "wolfram_alpha",
-                "photogen",
-                "code_interpreter"
-              ],
-              "title": "BuiltinTool"
-            },
-            {
-              "type": "string"
-            }
-          ]
+          "type": "string"
         },
         "arguments": {
           "oneOf": [
@@ -4237,48 +4276,13 @@
       },
       "additionalProperties": false,
       "required": [
+        "type",
        "call_id",
        "tool_name",
        "arguments"
      ],
      "title": "ToolCall"
    },
-    "ToolDefinition": {
-      "type": "object",
-      "properties": {
-        "tool_name": {
-          "oneOf": [
-            {
-              "type": "string",
-              "enum": [
-                "brave_search",
-                "wolfram_alpha",
-                "photogen",
-                "code_interpreter"
-              ],
-              "title": "BuiltinTool"
-            },
-            {
-              "type": "string"
-            }
-          ]
-        },
-        "description": {
-          "type": "string"
-        },
-        "parameters": {
-          "type": "object",
-          "additionalProperties": {
-            "$ref": "#/components/schemas/ToolParamDefinition"
-          }
-        }
-      },
-      "additionalProperties": false,
-      "required": [
-        "tool_name"
-      ],
-      "title": "ToolDefinition"
-    },
     "ToolParamDefinition": {
       "type": "object",
       "properties": {
@@ -4428,6 +4432,36 @@
       "title": "UserMessage",
       "description": "A message from the user in a chat conversation."
     },
+    "WebSearchTool": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "web_search",
+          "default": "web_search"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
+      ],
+      "title": "WebSearchTool"
+    },
+    "WolframAlphaTool": {
+      "type": "object",
+      "properties": {
+        "type": {
+          "type": "string",
+          "const": "wolfram_alpha",
+          "default": "wolfram_alpha"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "type"
+      ],
+      "title": "WolframAlphaTool"
+    },
     "BatchChatCompletionRequest": {
       "type": "object",
       "properties": {
@@ -4449,7 +4483,29 @@
         "tools": {
           "type": "array",
           "items": {
-            "$ref": "#/components/schemas/ToolDefinition"
+            "oneOf": [
+              {
+                "$ref": "#/components/schemas/WebSearchTool"
+              },
+              {
+                "$ref": "#/components/schemas/WolframAlphaTool"
+              },
+              {
+                "$ref": "#/components/schemas/CodeInterpreterTool"
+              },
+              {
+                "$ref": "#/components/schemas/FunctionTool"
+              }
+            ],
+            "discriminator": {
+              "propertyName": "type",
+              "mapping": {
+                "web_search": "#/components/schemas/WebSearchTool",
+                "wolfram_alpha": "#/components/schemas/WolframAlphaTool",
+                "code_interpreter": "#/components/schemas/CodeInterpreterTool",
+                "function": "#/components/schemas/FunctionTool"
+              }
+            }
          }
        },
        "tool_choice": {
@@ -4734,6 +4790,41 @@
       "title": "ToolConfig",
       "description": "Configuration for tool use."
     },
+    "ToolDefinitionDeprecated": {
+      "type": "object",
+      "properties": {
+        "tool_name": {
+          "oneOf": [
+            {
+              "type": "string",
+              "enum": [
+                "brave_search",
+                "wolfram_alpha",
+                "code_interpreter"
+              ],
+              "title": "BuiltinTool"
+            },
+            {
+              "type": "string"
+            }
+          ]
+        },
+        "description": {
+          "type": "string"
+        },
+        "parameters": {
+          "type": "object",
+          "additionalProperties": {
+            "$ref": "#/components/schemas/ToolParamDefinition"
+          }
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "tool_name"
+      ],
+      "title": "ToolDefinitionDeprecated"
+    },
     "ChatCompletionRequest": {
       "type": "object",
       "properties": {
@@ -4753,10 +4844,42 @@
           "description": "Parameters to control the sampling strategy"
         },
         "tools": {
-          "type": "array",
-          "items": {
-            "$ref": "#/components/schemas/ToolDefinition"
-          },
+          "oneOf": [
+            {
+              "type": "array",
+              "items": {
+                "oneOf": [
+                  {
+                    "$ref": "#/components/schemas/WebSearchTool"
+                  },
+                  {
+                    "$ref": "#/components/schemas/WolframAlphaTool"
+                  },
+                  {
+                    "$ref": "#/components/schemas/CodeInterpreterTool"
+                  },
+                  {
+                    "$ref": "#/components/schemas/FunctionTool"
+                  }
+                ],
+                "discriminator": {
+                  "propertyName": "type",
+                  "mapping": {
+                    "web_search": "#/components/schemas/WebSearchTool",
+                    "wolfram_alpha": "#/components/schemas/WolframAlphaTool",
+                    "code_interpreter": "#/components/schemas/CodeInterpreterTool",
+                    "function": "#/components/schemas/FunctionTool"
+                  }
+                }
+              }
+            },
+            {
+              "type": "array",
+              "items": {
+                "$ref": "#/components/schemas/ToolDefinitionDeprecated"
+              }
+            }
+          ],
          "description": "(Optional) List of tool definitions available to the model"
        },
        "tool_choice": {
@@ -5630,21 +5753,7 @@
         "type": "string"
       },
       "tool_name": {
-        "oneOf": [
-          {
-            "type": "string",
-            "enum": [
-              "brave_search",
-              "wolfram_alpha",
-              "photogen",
-              "code_interpreter"
-            ],
-            "title": "BuiltinTool"
-          },
-          {
-            "type": "string"
-          }
-        ]
+        "type": "string"
      },
      "content": {
        "$ref": "#/components/schemas/InterleavedContent"
```

**OpenAPI spec (YAML)**

```diff
@@ -2607,6 +2607,17 @@ components:
       required:
       - rows
       title: AppendRowsRequest
+    CodeInterpreterTool:
+      type: object
+      properties:
+        type:
+          type: string
+          const: code_interpreter
+          default: code_interpreter
+      additionalProperties: false
+      required:
+      - type
+      title: CodeInterpreterTool
     CompletionMessage:
       type: object
       properties:
@@ -2646,6 +2657,26 @@ components:
       title: CompletionMessage
       description: >-
         A message containing the model's (assistant) response in a chat conversation.
+    FunctionTool:
+      type: object
+      properties:
+        type:
+          type: string
+          const: function
+          default: function
+        name:
+          type: string
+        description:
+          type: string
+        parameters:
+          type: object
+          additionalProperties:
+            $ref: '#/components/schemas/ToolParamDefinition'
+      additionalProperties: false
+      required:
+      - type
+      - name
+      title: FunctionTool
     GrammarResponseFormat:
       type: object
       properties:
@@ -2851,18 +2882,18 @@ components:
     ToolCall:
       type: object
       properties:
+        type:
+          type: string
+          enum:
+          - function
+          - web_search
+          - wolfram_alpha
+          - code_interpreter
+          title: ToolType
         call_id:
           type: string
         tool_name:
-          oneOf:
-          - type: string
-            enum:
-            - brave_search
-            - wolfram_alpha
-            - photogen
-            - code_interpreter
-            title: BuiltinTool
-          - type: string
+          type: string
         arguments:
           oneOf:
           - type: string
@@ -2894,33 +2925,11 @@ components:
           type: string
       additionalProperties: false
       required:
+      - type
       - call_id
       - tool_name
       - arguments
       title: ToolCall
-    ToolDefinition:
-      type: object
-      properties:
-        tool_name:
-          oneOf:
-          - type: string
-            enum:
-            - brave_search
-            - wolfram_alpha
-            - photogen
-            - code_interpreter
-            title: BuiltinTool
-          - type: string
-        description:
-          type: string
-        parameters:
-          type: object
-          additionalProperties:
-            $ref: '#/components/schemas/ToolParamDefinition'
-      additionalProperties: false
-      required:
-      - tool_name
-      title: ToolDefinition
     ToolParamDefinition:
       type: object
       properties:
@@ -3031,6 +3040,28 @@ components:
       title: UserMessage
       description: >-
         A message from the user in a chat conversation.
+    WebSearchTool:
+      type: object
+      properties:
+        type:
+          type: string
+          const: web_search
+          default: web_search
+      additionalProperties: false
+      required:
+      - type
+      title: WebSearchTool
+    WolframAlphaTool:
+      type: object
+      properties:
+        type:
+          type: string
+          const: wolfram_alpha
+          default: wolfram_alpha
+      additionalProperties: false
+      required:
+      - type
+      title: WolframAlphaTool
     BatchChatCompletionRequest:
       type: object
       properties:
@@ -3047,7 +3078,18 @@ components:
         tools:
           type: array
           items:
-            $ref: '#/components/schemas/ToolDefinition'
+            oneOf:
+            - $ref: '#/components/schemas/WebSearchTool'
+            - $ref: '#/components/schemas/WolframAlphaTool'
+            - $ref: '#/components/schemas/CodeInterpreterTool'
+            - $ref: '#/components/schemas/FunctionTool'
+            discriminator:
+              propertyName: type
+              mapping:
+                web_search: '#/components/schemas/WebSearchTool'
+                wolfram_alpha: '#/components/schemas/WolframAlphaTool'
+                code_interpreter: '#/components/schemas/CodeInterpreterTool'
+                function: '#/components/schemas/FunctionTool'
         tool_choice:
           type: string
           enum:
@@ -3272,6 +3314,28 @@ components:
       additionalProperties: false
       title: ToolConfig
       description: Configuration for tool use.
+    ToolDefinitionDeprecated:
+      type: object
+      properties:
+        tool_name:
+          oneOf:
+          - type: string
+            enum:
+            - brave_search
+            - wolfram_alpha
+            - code_interpreter
+            title: BuiltinTool
+          - type: string
+        description:
+          type: string
+        parameters:
+          type: object
+          additionalProperties:
+            $ref: '#/components/schemas/ToolParamDefinition'
+      additionalProperties: false
+      required:
+      - tool_name
+      title: ToolDefinitionDeprecated
     ChatCompletionRequest:
       type: object
       properties:
@@ -3290,9 +3354,24 @@ components:
         description: >-
           Parameters to control the sampling strategy
         tools:
-          type: array
-          items:
-            $ref: '#/components/schemas/ToolDefinition'
+          oneOf:
+          - type: array
+            items:
+              oneOf:
+              - $ref: '#/components/schemas/WebSearchTool'
+              - $ref: '#/components/schemas/WolframAlphaTool'
+              - $ref: '#/components/schemas/CodeInterpreterTool'
+              - $ref: '#/components/schemas/FunctionTool'
+              discriminator:
+                propertyName: type
+                mapping:
+                  web_search: '#/components/schemas/WebSearchTool'
+                  wolfram_alpha: '#/components/schemas/WolframAlphaTool'
+                  code_interpreter: '#/components/schemas/CodeInterpreterTool'
+                  function: '#/components/schemas/FunctionTool'
+          - type: array
+            items:
+              $ref: '#/components/schemas/ToolDefinitionDeprecated'
          description: >-
            (Optional) List of tool definitions available to the model
        tool_choice:
@@ -3939,15 +4018,7 @@ components:
       call_id:
         type: string
       tool_name:
-        oneOf:
-        - type: string
-          enum:
-          - brave_search
-          - wolfram_alpha
-          - photogen
-          - code_interpreter
-          title: BuiltinTool
-        - type: string
+        type: string
      content:
        $ref: '#/components/schemas/InterleavedContent'
      metadata:
```

**Inference API definitions (Python)**

```diff
@@ -17,18 +17,18 @@ from typing import (
     runtime_checkable,
 )
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
 from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
 from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
     SamplingParams,
     StopReason,
     ToolCall,
     ToolDefinition,
+    ToolDefinitionDeprecated,
     ToolPromptFormat,
 )
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -156,23 +156,14 @@ Message = Annotated[
 register_schema(Message, name="Message")
 
 
+# TODO: move this to agent.py where this is used
 @json_schema_type
 class ToolResponse(BaseModel):
     call_id: str
-    tool_name: Union[BuiltinTool, str]
+    tool_name: str
     content: InterleavedContent
     metadata: Optional[Dict[str, Any]] = None
 
-    @field_validator("tool_name", mode="before")
-    @classmethod
-    def validate_field(cls, v):
-        if isinstance(v, str):
-            try:
-                return BuiltinTool(v)
-            except ValueError:
-                return v
-        return v
-
 
 class ToolChoice(Enum):
     """Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.
@@ -462,7 +453,8 @@ class Inference(Protocol):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
+        # TODO: remove ToolDefinitionDeprecated in v0.1.10
+        tools: Optional[List[ToolDefinition] | List[ToolDefinitionDeprecated]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
```

**Inference router (Python)**

```diff
@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     ToolConfig,
     ToolDefinition,
+    ToolDefinitionDeprecated,
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
@@ -54,6 +55,9 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
+from llama_stack.models.llama.datatypes import (
+    ToolType,
+)
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import RoutingTable
@@ -229,7 +233,7 @@ class InferenceRouter(Inference):
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
-        tools: Optional[List[ToolDefinition]] = None,
+        tools: Optional[List[ToolDefinition] | List[ToolDefinitionDeprecated]] = None,
         tool_choice: Optional[ToolChoice] = None,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
@@ -259,24 +263,42 @@ class InferenceRouter(Inference):
             params["tool_prompt_format"] = tool_prompt_format
         tool_config = ToolConfig(**params)
 
-        tools = tools or []
+        # TODO: remove ToolDefinitionDeprecated in v0.1.10
+        converted_tools = []
+        for tool in tools or []:
+            if isinstance(tool, ToolDefinitionDeprecated):
+                logger.warning(f"ToolDefinitionDeprecated: {tool}, use ToolDefinition instead")
+                converted_tools.append(tool.to_tool_definition())
+            else:
+                converted_tools.append(tool)
         if tool_config.tool_choice == ToolChoice.none:
-            tools = []
+            converted_tools = []
         elif tool_config.tool_choice == ToolChoice.auto:
             pass
         elif tool_config.tool_choice == ToolChoice.required:
             pass
         else:
             # verify tool_choice is one of the tools
-            tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
-            if tool_config.tool_choice not in tool_names:
-                raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
+            for t in converted_tools:
+                if t.type == ToolType.function.value:
+                    if tool_config.tool_choice == t.name:
+                        break
+                elif t.type in (
+                    ToolType.web_search.value,
+                    ToolType.wolfram_alpha.value,
+                    ToolType.code_interpreter.value,
+                ):
+                    if tool_config.tool_choice == t.type:
+                        break
+            else:
+                raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {converted_tools}")
 
         params = dict(
             model_id=model_id,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=converted_tools,
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             response_format=response_format,
```

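One consequence of the router change above: a string `tool_choice` is now matched against a function tool's `name`, or against a builtin tool's `type`. A sketch (assuming `ToolConfig.tool_choice` accepts a plain tool-name string, as the validation code implies):

```python
from llama_stack.apis.inference import ToolConfig

ToolConfig(tool_choice="get_weather")  # valid if FunctionTool(name="get_weather") is passed
ToolConfig(tool_choice="web_search")   # valid if WebSearchTool() is passed (matched by type)
# Any other value causes chat_completion to raise ValueError.
```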
**Llama model datatypes (Python)**

```diff
@@ -33,10 +33,10 @@ class Role(Enum):
     tool = "tool"
 
 
-class BuiltinTool(Enum):
-    brave_search = "brave_search"
+class ToolType(Enum):
+    function = "function"
+    web_search = "web_search"
     wolfram_alpha = "wolfram_alpha"
-    photogen = "photogen"
     code_interpreter = "code_interpreter"
 
 
@@ -45,8 +45,9 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
 
 class ToolCall(BaseModel):
+    type: ToolType
     call_id: str
-    tool_name: Union[BuiltinTool, str]
+    tool_name: str
     # Plan is to deprecate the Dict in favor of a JSON string
     # that is parsed on the client side instead of trying to manage
     # the recursive type here.
@@ -59,12 +60,18 @@ class ToolCall(BaseModel):
     @field_validator("tool_name", mode="before")
     @classmethod
     def validate_field(cls, v):
+        # for backwards compatibility, we allow the tool name to be a string or a BuiltinTool
+        # TODO: remove ToolDefinitionDeprecated in v0.1.10
+        tool_name = v
         if isinstance(v, str):
             try:
-                return BuiltinTool(v)
+                tool_name = BuiltinTool(v)
             except ValueError:
-                return v
-        return v
+                pass
+        if isinstance(tool_name, BuiltinTool):
+            return tool_name.to_tool().type
+        return tool_name
 
 
 class ToolPromptFormat(Enum):
@@ -151,8 +158,136 @@ class ToolParamDefinition(BaseModel):
     default: Optional[Any] = None
 
 
+class Tool(BaseModel):
+    type: ToolType
+
+    @classmethod
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        required_properties = ["name", "description", "parameters"]
+        for prop in required_properties:
+            has_property = any(isinstance(v, property) for v in [cls.__dict__.get(prop)])
+            has_field = prop in cls.__annotations__ or prop in cls.__dict__
+            if not has_property and not has_field:
+                raise TypeError(f"Class {cls.__name__} must implement '{prop}' property or field")
+
+
-@json_schema_type
-class ToolDefinition(BaseModel):
+@json_schema_type
+class WebSearchTool(Tool):
+    type: Literal[ToolType.web_search.value] = ToolType.web_search.value
+
+    @property
+    def name(self) -> str:
+        return "web_search"
+
+    @property
+    def description(self) -> str:
+        return "Search the web for information"
+
+    @property
+    def parameters(self) -> Dict[str, ToolParamDefinition]:
+        return {
+            "query": ToolParamDefinition(
+                description="The query to search for",
+                param_type="string",
+                required=True,
+            ),
+        }
+
+
+@json_schema_type
+class WolframAlphaTool(Tool):
+    type: Literal[ToolType.wolfram_alpha.value] = ToolType.wolfram_alpha.value
+
+    @property
+    def name(self) -> str:
+        return "wolfram_alpha"
+
+    @property
+    def description(self) -> str:
+        return "Query WolframAlpha for computational knowledge"
+
+    @property
+    def parameters(self) -> Dict[str, ToolParamDefinition]:
+        return {
+            "query": ToolParamDefinition(
+                description="The query to compute",
+                param_type="string",
+                required=True,
+            ),
+        }
+
+
+@json_schema_type
+class CodeInterpreterTool(Tool):
+    type: Literal[ToolType.code_interpreter.value] = ToolType.code_interpreter.value
+
+    @property
+    def name(self) -> str:
+        return "code_interpreter"
+
+    @property
+    def description(self) -> str:
+        return "Execute code"
+
+    @property
+    def parameters(self) -> Dict[str, ToolParamDefinition]:
+        return {
+            "code": ToolParamDefinition(
+                description="The code to execute",
+                param_type="string",
+                required=True,
+            ),
+        }
+
+
+@json_schema_type
+class FunctionTool(Tool):
+    type: Literal[ToolType.function.value] = ToolType.function.value
+    name: str
+    description: Optional[str] = None
+    parameters: Optional[Dict[str, ToolParamDefinition]] = None
+
+    @field_validator("name", mode="before")
+    @classmethod
+    def validate_name(cls, v):
+        if v in ToolType.__members__:
+            raise ValueError(f"Tool name '{v}' is a tool type and cannot be used as a name of a function tool")
+        return v
+
+
+ToolDefinition = Annotated[
+    Union[WebSearchTool, WolframAlphaTool, CodeInterpreterTool, FunctionTool], Field(discriminator="type")
+]
+
+
+# TODO: remove ToolDefinitionDeprecated in v0.1.10
+class BuiltinTool(Enum):
+    brave_search = "brave_search"
+    wolfram_alpha = "wolfram_alpha"
+    code_interpreter = "code_interpreter"
+
+    def to_tool_type(self) -> ToolType:
+        if self == BuiltinTool.brave_search:
+            return ToolType.web_search
+        elif self == BuiltinTool.wolfram_alpha:
+            return ToolType.wolfram_alpha
+        elif self == BuiltinTool.code_interpreter:
+            return ToolType.code_interpreter
+
+    def to_tool(self) -> WebSearchTool | WolframAlphaTool | CodeInterpreterTool:
+        if self == BuiltinTool.brave_search:
+            return WebSearchTool()
+        elif self == BuiltinTool.wolfram_alpha:
+            return WolframAlphaTool()
+        elif self == BuiltinTool.code_interpreter:
+            return CodeInterpreterTool()
+
+
+# TODO: remove ToolDefinitionDeprecated in v0.1.10
+@json_schema_type
+class ToolDefinitionDeprecated(BaseModel):
     tool_name: Union[BuiltinTool, str]
     description: Optional[str] = None
     parameters: Optional[Dict[str, ToolParamDefinition]] = None
@@ -167,6 +302,21 @@ class ToolDefinitionDeprecated(BaseModel):
                 return v
         return v
 
+    def to_tool_definition(self) -> ToolDefinition:
+        # convert to ToolDefinition
+        if self.tool_name == BuiltinTool.brave_search:
+            return WebSearchTool()
+        elif self.tool_name == BuiltinTool.code_interpreter:
+            return CodeInterpreterTool()
+        elif self.tool_name == BuiltinTool.wolfram_alpha:
+            return WolframAlphaTool()
+        else:
+            return FunctionTool(
+                name=self.tool_name,
+                description=self.description,
+                parameters=self.parameters,
+            )
+
 
 @json_schema_type
 class GreedySamplingStrategy(BaseModel):
```

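Since the new `ToolDefinition` is a discriminated union on `type`, payloads deserialize unambiguously, and function names can no longer collide with builtin tool types. A quick sketch (assumes pydantic v2's `TypeAdapter`):

```python
from pydantic import TypeAdapter

from llama_stack.models.llama.datatypes import ToolDefinition, WebSearchTool

adapter = TypeAdapter(ToolDefinition)
tool = adapter.validate_python({"type": "web_search"})
assert isinstance(tool, WebSearchTool)
assert tool.name == "web_search"

# Raises pydantic.ValidationError: "web_search" is a builtin tool type,
# so FunctionTool's name validator rejects it.
adapter.validate_python({"type": "function", "name": "web_search"})
```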
**Llama 3 chat format (Python)**

```diff
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Tuple
 from PIL import Image as PIL_Image
 
 from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
     RawContent,
     RawMediaItem,
     RawMessage,
@@ -29,6 +28,7 @@ from llama_stack.models.llama.datatypes import (
     StopReason,
     ToolCall,
     ToolPromptFormat,
+    ToolType,
 )
 
 from .tokenizer import Tokenizer
@@ -127,7 +127,7 @@ class ChatFormat:
         if (
             message.role == "assistant"
             and len(message.tool_calls) > 0
-            and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter
+            and message.tool_calls[0].type == ToolType.code_interpreter
         ):
             tokens.append(self.tokenizer.special_tokens["<|python_tag|>"])
@@ -194,6 +194,7 @@ class ChatFormat:
             stop_reason = StopReason.end_of_message
 
         tool_name = None
+        tool_type = ToolType.function
         tool_arguments = {}
 
         custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
@@ -202,8 +203,8 @@ class ChatFormat:
             # Sometimes when agent has custom tools alongside builin tools
             # Agent responds for builtin tool calls in the format of the custom tools
             # This code tries to handle that case
-            if tool_name in BuiltinTool.__members__:
-                tool_name = BuiltinTool[tool_name]
+            if tool_name in ToolType.__members__:
+                tool_type = ToolType[tool_name]
                 if isinstance(tool_arguments, dict):
                     tool_arguments = {
                         "query": list(tool_arguments.values())[0],
@@ -215,10 +216,11 @@ class ChatFormat:
             tool_arguments = {
                 "query": query,
             }
-            if tool_name in BuiltinTool.__members__:
-                tool_name = BuiltinTool[tool_name]
+            if tool_name in ToolType.__members__:
+                tool_type = ToolType[tool_name]
         elif ipython:
-            tool_name = BuiltinTool.code_interpreter
+            tool_name = ToolType.code_interpreter.value
+            tool_type = ToolType.code_interpreter
             tool_arguments = {
                 "code": content,
             }
@@ -228,6 +230,7 @@ class ChatFormat:
             call_id = str(uuid.uuid4())
             tool_calls.append(
                 ToolCall(
+                    type=tool_type,
                     call_id=call_id,
                     tool_name=tool_name,
                     arguments=tool_arguments,
```

**Llama 3.1 interface (Python)**

```diff
@@ -17,7 +17,7 @@ from typing import List, Optional
 from termcolor import colored
 
 from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
+    FunctionTool,
     RawMessage,
     StopReason,
     ToolCall,
@@ -25,7 +25,6 @@ from llama_stack.models.llama.datatypes import (
     ToolPromptFormat,
 )
 
-from . import template_data
 from .chat_format import ChatFormat
 from .prompt_templates import (
     BuiltinToolGenerator,
@@ -150,8 +149,8 @@ class LLama31Interface:
 
     def system_messages(
         self,
-        builtin_tools: List[BuiltinTool],
-        custom_tools: List[ToolDefinition],
+        builtin_tools: List[ToolDefinition],
+        custom_tools: List[FunctionTool],
         instruction: Optional[str] = None,
     ) -> List[RawMessage]:
         messages = []
@@ -227,31 +226,3 @@ class LLama31Interface:
             on_col = on_colors[i % len(on_colors)]
             print(colored(self.tokenizer.decode([t]), "white", on_col), end="")
         print("\n", end="")
-
-
-def list_jinja_templates() -> List[Template]:
-    return TEMPLATES
-
-
-def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat):
-    by_name = {t.template_name: t for t in TEMPLATES}
-    if name not in by_name:
-        raise ValueError(f"No template found for `{name}`")
-
-    template = by_name[name]
-    interface = LLama31Interface(tool_prompt_format)
-
-    data_func = getattr(template_data, template.data_provider)
-    if template.role == "system":
-        messages = interface.system_messages(**data_func())
-    elif template.role == "tool":
-        messages = interface.tool_response_messages(**data_func())
-    elif template.role == "assistant":
-        messages = interface.assistant_response_messages(**data_func())
-    elif template.role == "user":
-        messages = interface.user_message(**data_func())
-
-    tokens = interface.get_tokens(messages)
-    special_tokens = list(interface.tokenizer.special_tokens.values())
-    tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
-    return template, tokens
```

**Prompt template generators (Python)**

```diff
@@ -16,9 +16,13 @@ from datetime import datetime
 from typing import Any, List, Optional
 
 from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
+    CodeInterpreterTool,
+    FunctionTool,
     ToolDefinition,
     ToolParamDefinition,
+    ToolType,
+    WebSearchTool,
+    WolframAlphaTool,
 )
 
 from .base import PromptTemplate, PromptTemplateGeneratorBase
@@ -47,7 +51,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
     def _tool_breakdown(self, tools: List[ToolDefinition]):
         builtin_tools, custom_tools = [], []
         for dfn in tools:
-            if isinstance(dfn.tool_name, BuiltinTool):
+            if dfn.type != ToolType.function.value:
                 builtin_tools.append(dfn)
             else:
                 custom_tools.append(dfn)
@@ -70,7 +74,11 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
         return PromptTemplate(
             template_str.lstrip("\n"),
             {
-                "builtin_tools": [t.tool_name.value for t in builtin_tools],
+                "builtin_tools": [
+                    # brave_search is used in training data for web_search
+                    t.type if t.type != ToolType.web_search.value else "brave_search"
+                    for t in builtin_tools
+                ],
                 "custom_tools": custom_tools,
             },
         )
@@ -79,19 +87,19 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
         return [
             # builtin tools
             [
-                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-                ToolDefinition(tool_name=BuiltinTool.brave_search),
-                ToolDefinition(tool_name=BuiltinTool.wolfram_alpha),
+                CodeInterpreterTool(),
+                WebSearchTool(),
+                WolframAlphaTool(),
             ],
             # only code interpretor
             [
-                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
+                CodeInterpreterTool(),
             ],
         ]
 
 
 class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
-    def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+    def gen(self, custom_tools: List[FunctionTool]) -> PromptTemplate:
         template_str = textwrap.dedent(
             """
             Answer the user's question by making use of the following functions if needed.
@@ -99,7 +107,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
             Here is a list of functions in JSON format:
             {% for t in custom_tools -%}
             {# manually setting up JSON because jinja sorts keys in unexpected ways -#}
-            {%- set tname = t.tool_name -%}
+            {%- set tname = t.name -%}
             {%- set tdesc = t.description -%}
             {%- set tparams = t.parameters -%}
             {%- set required_params = [] -%}
@@ -140,8 +148,8 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
     def data_examples(self) -> List[List[ToolDefinition]]:
         return [
             [
-                ToolDefinition(
-                    tool_name="trending_songs",
+                FunctionTool(
+                    name="trending_songs",
                     description="Returns the trending songs on a Music site",
                     parameters={
                         "n": ToolParamDefinition(
@@ -161,14 +169,14 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
 
 class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
-    def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+    def gen(self, custom_tools: List[FunctionTool]) -> PromptTemplate:
         template_str = textwrap.dedent(
             """
             You have access to the following functions:
 
             {% for t in custom_tools %}
             {#- manually setting up JSON because jinja sorts keys in unexpected ways -#}
-            {%- set tname = t.tool_name -%}
+            {%- set tname = t.name -%}
             {%- set tdesc = t.description -%}
             {%- set modified_params = t.parameters.copy() -%}
             {%- for key, value in modified_params.items() -%}
@@ -202,8 +210,8 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
     def data_examples(self) -> List[List[ToolDefinition]]:
         return [
             [
-                ToolDefinition(
-                    tool_name="trending_songs",
+                FunctionTool(
+                    name="trending_songs",
                     description="Returns the trending songs on a Music site",
                     parameters={
                         "n": ToolParamDefinition(
@@ -240,7 +248,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
             {"function_description": self._gen_function_description(custom_tools)},
         )
 
-    def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+    def _gen_function_description(self, custom_tools: List[FunctionTool]) -> PromptTemplate:
         template_str = textwrap.dedent(
             """
             If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
@@ -252,7 +260,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
             [
             {% for t in tools -%}
             {# manually setting up JSON because jinja sorts keys in unexpected ways -#}
-            {%- set tname = t.tool_name -%}
+            {%- set tname = t.name -%}
             {%- set tdesc = t.description -%}
             {%- set tparams = t.parameters -%}
             {%- set required_params = [] -%}
@@ -289,8 +297,8 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
     def data_examples(self) -> List[List[ToolDefinition]]:
         return [
             [
-                ToolDefinition(
-                    tool_name="get_weather",
+                FunctionTool(
+                    name="get_weather",
                     description="Get weather info for places",
                     parameters={
                         "city": ToolParamDefinition(
```

**Tool utils (Python)**

```diff
@@ -16,7 +16,7 @@ import re
 from typing import Optional, Tuple
 
 from llama_stack.log import get_logger
-from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
+from llama_stack.models.llama.datatypes import RecursiveType, ToolCall, ToolPromptFormat, ToolType
 
 logger = get_logger(name=__name__, category="inference")
@@ -24,6 +24,12 @@ BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
 CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
 
+# The model is trained with brave_search for web_search, so we need to map it
+TOOL_NAME_MAP = {
+    "brave_search": ToolType.web_search.value,
+}
+
 
 def is_json(s):
     try:
         parsed = json.loads(s)
@@ -111,11 +117,6 @@ def parse_python_list_for_function_calls(input_string):
 class ToolUtils:
-    @staticmethod
-    def is_builtin_tool_call(message_body: str) -> bool:
-        match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body)
-        return match is not None
-
     @staticmethod
     def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
         # Find the first match in the text
@@ -125,7 +126,7 @@ class ToolUtils:
         if match:
             tool_name = match.group("tool_name")
             query = match.group("query")
-            return tool_name, query
+            return TOOL_NAME_MAP.get(tool_name, tool_name), query
         else:
             return None
@@ -143,7 +144,7 @@ class ToolUtils:
             tool_name = match.group("function_name")
             query = match.group("args")
             try:
-                return tool_name, json.loads(query.replace("'", '"'))
+                return TOOL_NAME_MAP.get(tool_name, tool_name), json.loads(query.replace("'", '"'))
             except Exception as e:
                 print("Exception while parsing json query for custom tool call", query, e)
                 return None
@@ -152,30 +153,28 @@ class ToolUtils:
             if ("type" in response and response["type"] == "function") or ("name" in response):
                 function_name = response["name"]
                 args = response["parameters"]
-                return function_name, args
+                return TOOL_NAME_MAP.get(function_name, function_name), args
             else:
                 return None
         elif is_valid_python_list(message_body):
             res = parse_python_list_for_function_calls(message_body)
             # FIXME: Enable multiple tool calls
-            return res[0]
+            function_name, args = res[0]
+            return TOOL_NAME_MAP.get(function_name, function_name), args
         else:
             return None
 
     @staticmethod
     def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
-        if t.tool_name == BuiltinTool.brave_search:
+        if t.type == ToolType.web_search:
             q = t.arguments["query"]
             return f'brave_search.call(query="{q}")'
-        elif t.tool_name == BuiltinTool.wolfram_alpha:
+        elif t.type == ToolType.wolfram_alpha:
             q = t.arguments["query"]
             return f'wolfram_alpha.call(query="{q}")'
-        elif t.tool_name == BuiltinTool.photogen:
-            q = t.arguments["query"]
-            return f'photogen.call(query="{q}")'
-        elif t.tool_name == BuiltinTool.code_interpreter:
+        elif t.type == ToolType.code_interpreter:
             return t.arguments["code"]
-        else:
+        elif t.type == ToolType.function:
             fname = t.tool_name
 
             if tool_prompt_format == ToolPromptFormat.json:
@@ -208,3 +207,5 @@ class ToolUtils:
                 return f"[{fname}({args_str})]"
             else:
                 raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
+        else:
+            raise ValueError(f"Unsupported tool type: {t.type}")
```

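Net effect on prompt construction: the model still sees and emits `brave_search` (matching its training data), while everything above the prompt layer uses `web_search`. Roughly (the `tool_utils` module path is assumed here):

```python
from llama_stack.models.llama.datatypes import ToolCall, ToolPromptFormat, ToolType
from llama_stack.models.llama.tool_utils import TOOL_NAME_MAP, ToolUtils  # path assumed

call = ToolCall(
    type=ToolType.web_search,
    call_id="call-1",
    tool_name="web_search",
    arguments={"query": "Llama release notes"},
)
# Encoding uses the name the model was trained on:
assert ToolUtils.encode_tool_call(call, ToolPromptFormat.json) == 'brave_search.call(query="Llama release notes")'
# Decoding maps the trained name back to the public one:
assert TOOL_NAME_MAP["brave_search"] == "web_search"
```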
**Prompt-format docs (Python, first of two similar files)**

```diff
@@ -15,11 +15,11 @@ import textwrap
 from typing import List
 
 from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
     RawMessage,
     StopReason,
     ToolCall,
     ToolPromptFormat,
+    ToolType,
 )
 
 from ..prompt_format import (
@@ -184,8 +184,9 @@ def usecases() -> List[UseCase | str]:
                     stop_reason=StopReason.end_of_message,
                     tool_calls=[
                         ToolCall(
+                            type=ToolType.wolfram_alpha,
                             call_id="tool_call_id",
-                            tool_name=BuiltinTool.wolfram_alpha,
+                            tool_name=ToolType.wolfram_alpha.value,
                             arguments={"query": "100th decimal of pi"},
                         )
                     ],
```

**Prompt-format docs (Python, second of two similar files)**

```diff
@@ -15,11 +15,11 @@ import textwrap
 from typing import List
 
 from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
     RawMessage,
     StopReason,
     ToolCall,
     ToolPromptFormat,
+    ToolType,
 )
 
 from ..prompt_format import (
@@ -183,8 +183,9 @@ def usecases() -> List[UseCase | str]:
                     stop_reason=StopReason.end_of_message,
                     tool_calls=[
                         ToolCall(
+                            type=ToolType.wolfram_alpha,
                             call_id="tool_call_id",
-                            tool_name=BuiltinTool.wolfram_alpha,
+                            tool_name=ToolType.wolfram_alpha.value,
                             arguments={"query": "100th decimal of pi"},
                         )
                     ],
```

**Agent implementation (Python)**

```diff
@@ -53,7 +53,7 @@ from llama_stack.apis.inference import (
     SamplingParams,
     StopReason,
     SystemMessage,
-    ToolDefinition,
+    ToolDefinitionDeprecated,
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
@@ -771,7 +771,7 @@ class ChatAgent(ShieldRunnerMixin):
         for tool_def in self.agent_config.client_tools:
             if tool_name_to_def.get(tool_def.name, None):
                 raise ValueError(f"Tool {tool_def.name} already exists")
-            tool_name_to_def[tool_def.name] = ToolDefinition(
+            tool_name_to_def[tool_def.name] = ToolDefinitionDeprecated(
                 tool_name=tool_def.name,
                 description=tool_def.description,
                 parameters={
@@ -814,7 +814,7 @@ class ChatAgent(ShieldRunnerMixin):
                 if tool_name_to_def.get(identifier, None):
                     raise ValueError(f"Tool {identifier} already exists")
                 if identifier:
-                    tool_name_to_def[tool_def.identifier] = ToolDefinition(
+                    tool_name_to_def[tool_def.identifier] = ToolDefinitionDeprecated(
                         tool_name=identifier,
                         description=tool_def.description,
                         parameters={
@@ -854,30 +854,23 @@ class ChatAgent(ShieldRunnerMixin):
         tool_call: ToolCall,
     ) -> ToolInvocationResult:
         tool_name = tool_call.tool_name
-        registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs]
+        registered_tool_names = list(self.tool_name_to_args.keys())
         if tool_name not in registered_tool_names:
             raise ValueError(
                 f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}"
             )
-        if isinstance(tool_name, BuiltinTool):
-            if tool_name == BuiltinTool.brave_search:
-                tool_name_str = WEB_SEARCH_TOOL
-            else:
-                tool_name_str = tool_name.value
-        else:
-            tool_name_str = tool_name
 
-        logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
+        logger.info(f"executing tool call: {tool_name} with args: {tool_call.arguments}")
         result = await self.tool_runtime_api.invoke_tool(
-            tool_name=tool_name_str,
+            tool_name=tool_name,
             kwargs={
                 "session_id": session_id,
                 # get the arguments generated by the model and augment with toolgroup arg overrides for the agent
                 **tool_call.arguments,
-                **self.tool_name_to_args.get(tool_name_str, {}),
+                **self.tool_name_to_args.get(tool_name, {}),
             },
         )
-        logger.debug(f"tool call {tool_name_str} completed with result: {result}")
+        logger.debug(f"tool call {tool_name} completed with result: {result}")
         return result
 
     async def handle_documents(
```

**OpenAI-compatible provider: tool conversion (Python)**

```diff
@@ -16,7 +16,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     UserMessage,
 )
-from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
+from llama_stack.models.llama.datatypes import ToolDefinition, ToolType
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,
@@ -65,7 +65,7 @@ def _llama_stack_tools_to_openai_tools(
     result = []
     for t in tools:
-        if isinstance(t.tool_name, BuiltinTool):
+        if t.type != ToolType.function.value:
             raise NotImplementedError("Built-in tools not yet implemented")
         if t.parameters is None:
             parameters = None
```

**vLLM provider (Python)**

```diff
@@ -45,7 +45,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
-from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
+from llama_stack.models.llama.datatypes import StopReason, ToolCall
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
@@ -110,6 +110,8 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
     for tool in tools:
         properties = {}
         compat_required = []
+        tool_name = tool.name
         if tool.parameters:
             for tool_key, tool_param in tool.parameters.items():
                 properties[tool_key] = {"type": tool_param.param_type}
@@ -120,12 +122,6 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
                 if tool_param.required:
                     compat_required.append(tool_key)
 
-        # The tool.tool_name can be a str or a BuiltinTool enum. If
-        # it's the latter, convert to a string.
-        tool_name = tool.tool_name
-        if isinstance(tool_name, BuiltinTool):
-            tool_name = tool_name.value
-
         compat_tool = {
             "type": "function",
             "function": {
```

**Brave Search tool runtime (Python)**

```diff
@@ -17,7 +17,6 @@ from llama_stack.apis.tools import (
     ToolRuntime,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
 
 from .config import BraveSearchToolConfig
@@ -61,7 +60,6 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
                         parameter_type="string",
                     )
                 ],
-                built_in_type=BuiltinTool.brave_search,
             )
         ]
```

**OpenAI compatibility utils (Python)**

```diff
@@ -80,12 +80,12 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
     GreedySamplingStrategy,
     SamplingParams,
     StopReason,
     ToolCall,
     ToolDefinition,
+    ToolType,
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
@@ -271,7 +271,7 @@ def process_chat_completion_response(
     else:
         # only return tool_calls if provided in the request
         new_tool_calls = []
-        request_tools = {t.tool_name: t for t in request.tools}
+        request_tools = {t.name: t for t in request.tools}
         for t in raw_message.tool_calls:
             if t.tool_name in request_tools:
                 new_tool_calls.append(t)
@@ -423,7 +423,7 @@ async def process_chat_completion_stream_response(
             )
         )
 
-    request_tools = {t.tool_name: t for t in request.tools}
+    request_tools = {t.name: t for t in request.tools}
     for tool_call in message.tool_calls:
         if tool_call.tool_name in request_tools:
             yield ChatCompletionResponseStreamChunk(
@@ -574,7 +574,7 @@ async def convert_message_to_openai_dict_new(
                 OpenAIChatCompletionMessageToolCall(
                     id=tool.call_id,
                     function=OpenAIFunction(
-                        name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
+                        name=tool.tool_name,
                         arguments=json.dumps(tool.arguments),
                     ),
                     type="function",
@@ -638,7 +638,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
     Convert a ToolDefinition to an OpenAI API-compatible dictionary.
 
     ToolDefinition:
-        tool_name: str | BuiltinTool
+        tool_name: str
         description: Optional[str]
         parameters: Optional[Dict[str, ToolParamDefinition]]
@@ -677,10 +677,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
     }
     function = out["function"]
 
-    if isinstance(tool.tool_name, BuiltinTool):
-        function.update(name=tool.tool_name.value)  # TODO(mf): is this sufficient?
-    else:
-        function.update(name=tool.tool_name)
+    function.update(name=tool.name)
 
     if tool.description:
         function.update(description=tool.description)
@@ -761,6 +758,7 @@ def _convert_openai_tool_calls(
     return [
         ToolCall(
+            type=ToolType.function,
             call_id=call.id,
             tool_name=call.function.name,
             arguments=json.loads(call.function.arguments),
@@ -975,6 +973,7 @@ async def convert_openai_chat_completion_stream(
             try:
                 arguments = json.loads(buffer["arguments"])
                 tool_call = ToolCall(
+                    type=ToolType.function,
                     call_id=buffer["call_id"],
                     tool_name=buffer["name"],
                     arguments=arguments,
```

View file

@@ -43,6 +43,7 @@ from llama_stack.models.llama.datatypes import (
     Role,
     StopReason,
     ToolPromptFormat,
+    ToolType,
     is_multimodal,
 )
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
@@ -374,8 +375,8 @@ def augment_messages_for_tools_llama_3_1(
     messages.append(SystemMessage(content=sys_content))

-    has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
-    if has_custom_tools:
+    custom_tools = [t for t in request.tools if t.type == ToolType.function.value]
+    if custom_tools:
         fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
         if fmt == ToolPromptFormat.json:
             tool_gen = JsonCustomToolGenerator()
@@ -384,7 +385,6 @@ def augment_messages_for_tools_llama_3_1(
         else:
             raise ValueError(f"Non supported ToolPromptFormat {fmt}")

-        custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
         custom_template = tool_gen.gen(custom_tools)
         messages.append(UserMessage(content=custom_template.render()))
@@ -407,7 +407,7 @@ def augment_messages_for_tools_llama_3_2(
     sys_content = ""
     custom_tools, builtin_tools = [], []
     for t in request.tools:
-        if isinstance(t.tool_name, str):
+        if t.type == ToolType.function.value:
             custom_tools.append(t)
         else:
             builtin_tools.append(t)
@@ -419,7 +419,7 @@ def augment_messages_for_tools_llama_3_2(
         sys_content += tool_template.render()
         sys_content += "\n"

-    custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
+    custom_tools = [t for t in request.tools if t.type == ToolType.function.value]
     if custom_tools:
         fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
         if fmt != ToolPromptFormat.python_list:
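
The `isinstance(tool_name, str)` checks are replaced by the formal `type` discriminator throughout the prompt adapters. A minimal sketch of the new partitioning, with illustrative dataclass stand-ins for the typed tools (not the real pydantic models):

```
from dataclasses import dataclass


@dataclass
class WebSearchTool:
    type: str = "web_search"


@dataclass
class CodeInterpreterTool:
    type: str = "code_interpreter"


@dataclass
class FunctionTool:
    name: str
    type: str = "function"


tools = [WebSearchTool(), CodeInterpreterTool(), FunctionTool(name="custom1")]

# Builtin vs. custom is decided by the discriminator, not by the Python type of the name.
custom_tools = [t for t in tools if t.type == "function"]
builtin_tools = [t for t in tools if t.type != "function"]
assert [t.name for t in custom_tools] == ["custom1"]
assert {t.type for t in builtin_tools} == {"web_search", "code_interpreter"}
```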

View file

@@ -188,16 +188,12 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
             }
         ],
         session_id=session_id,
+        stream=False,
     )

-    logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
-    logs_str = "".join(logs)
-
-    assert "tool_execution>" in logs_str
-    assert "Tool:brave_search Response:" in logs_str
-    assert "mark zuckerberg" in logs_str.lower()
-    if len(agent_config["output_shields"]) > 0:
-        assert "No Violation" in logs_str
+    tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
+    assert tool_execution_step.tool_calls[0].tool_name == "web_search"
+    assert "mark zuckerberg" in response.output_message.content.lower()


 def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
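
The rewritten test asserts on the structured `steps` of a non-streaming turn instead of scraping event-logger output, so failures point at the exact step. A hypothetical response object (shapes assumed for illustration, not the real client types) showing the assertion pattern:

```
from types import SimpleNamespace

# Assumed shape of a non-streaming agent turn; step_type and tool_calls
# mirror the fields the updated test asserts on.
response = SimpleNamespace(
    steps=[
        SimpleNamespace(step_type="inference", tool_calls=[]),
        SimpleNamespace(
            step_type="tool_execution",
            tool_calls=[SimpleNamespace(tool_name="web_search")],
        ),
    ],
    output_message=SimpleNamespace(content="Mark Zuckerberg is the CEO of Meta."),
)

tool_execution_step = next(s for s in response.steps if s.step_type == "tool_execution")
assert tool_execution_step.tool_calls[0].tool_name == "web_search"
assert "mark zuckerberg" in response.output_message.content.lower()
```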

View file

@@ -284,6 +284,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
     "test_case",
     [
         "inference:chat_completion:tool_calling",
+        "inference:chat_completion:tool_calling_deprecated",
     ],
 )
 def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
@@ -300,7 +301,9 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_mo
     assert response.completion_message.role == "assistant"
     assert len(response.completion_message.tool_calls) == 1
-    assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
+    assert response.completion_message.tool_calls[0].tool_name == (
+        tc["tools"][0]["tool_name"] if "tool_name" in tc["tools"][0] else tc["tools"][0]["name"]
+    )
     assert response.completion_message.tool_calls[0].arguments == tc["expected"]
@@ -334,7 +337,7 @@ def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models
         stream=True,
     )
     tool_invocation_content = extract_tool_invocation_content(response)
-    expected_tool_name = tc["tools"][0]["tool_name"]
+    expected_tool_name = tc["tools"][0]["tool_name"] if "tool_name" in tc["tools"][0] else tc["tools"][0]["name"]
     expected_argument = tc["expected"]
     assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]"
@@ -358,7 +361,7 @@ def test_text_chat_completion_with_tool_choice_required(client_with_models, text
         stream=True,
     )
     tool_invocation_content = extract_tool_invocation_content(response)
-    expected_tool_name = tc["tools"][0]["tool_name"]
+    expected_tool_name = tc["tools"][0]["tool_name"] if "tool_name" in tc["tools"][0] else tc["tools"][0]["name"]
     expected_argument = tc["expected"]
     assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]"
@@ -432,14 +435,11 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
 ):
     tc = TestCase(test_case)

-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
     request = {
         "model_id": text_model_id,
         "messages": tc["messages"],
         "tools": tc["tools"],
         "tool_choice": "auto",
-        "tool_prompt_format": tool_prompt_format,
         "stream": streaming,
     }
@@ -457,3 +457,30 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
     else:
         for tc in response.completion_message.tool_calls:
             assert tc.tool_name == "get_object_namespace_list"
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:tool_calling_builtin_web_search",
+        "inference:chat_completion:tool_calling_builtin_brave_search",
+        "inference:chat_completion:tool_calling_builtin_code_interpreter",
+        "inference:chat_completion:tool_calling_builtin_code_interpreter_deprecated",
+    ],
+)
+def test_text_chat_completion_tool_calling_builtin(client_with_models, text_model_id, test_case):
+    tc = TestCase(test_case)
+
+    request = {
+        "model_id": text_model_id,
+        "messages": tc["messages"],
+        "tools": tc["tools"],
+        "tool_choice": "auto",
+        "stream": False,
+    }
+    response = client_with_models.inference.chat_completion(**request)
+    for tool_call in response.completion_message.tool_calls:
+        print(tool_call)
+        assert tool_call.tool_name == tc["expected"]
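
The `brave_search` fixture asserting `web_search` encodes the compatibility rule from the RFC: the deprecated name is still accepted on input and normalized to the generic type, and it is only translated back to `brave_search` when the prompt is constructed. A minimal sketch of that normalization (the helper name and mapping are illustrative, not the real implementation):

```
# Deprecated builtin names accepted on the API surface, mapped to the new types.
DEPRECATED_TOOL_NAMES = {
    "brave_search": "web_search",
    "code_interpreter": "code_interpreter",
}


def normalize_tool(tool: dict) -> dict:
    # New-style tools already carry a "type" discriminator.
    if "type" in tool:
        return tool
    # Old-style {"tool_name": ...} entries are mapped to their new type.
    return {"type": DEPRECATED_TOOL_NAMES[tool["tool_name"]]}


assert normalize_tool({"tool_name": "brave_search"}) == {"type": "web_search"}
assert normalize_tool({"type": "code_interpreter"}) == {"type": "code_interpreter"}
```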

View file

@@ -49,7 +49,7 @@
             "expected": "Washington"
         }
     },
-    "tool_calling": {
+    "tool_calling_deprecated": {
         "data": {
             "messages": [
                 {"role": "system", "content": "Pretend you are a weather assistant."},
@@ -72,6 +72,30 @@
             }
         }
     },
+    "tool_calling": {
+        "data": {
+            "messages": [
+                {"role": "system", "content": "Pretend you are a weather assistant."},
+                {"role": "user", "content": "What's the weather like in San Francisco?"}
+            ],
+            "tools": [
+                {
+                    "type": "function",
+                    "name": "get_weather",
+                    "description": "Get the current weather",
+                    "parameters": {
+                        "location": {
+                            "param_type": "string",
+                            "description": "The city and state (both required), e.g. San Francisco, CA."
+                        }
+                    }
+                }
+            ],
+            "expected": {
+                "location": "San Francisco, CA"
+            }
+        }
+    },
     "sample_messages_tool_calling": {
         "data": {
             "messages": [
@@ -128,6 +152,56 @@
             }
         }
     },
+    "tool_calling_builtin_web_search": {
+        "data": {
+            "messages": [
+                {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."},
+                {"role": "user", "content": "What's the weather like in San Francisco?"}
+            ],
+            "tools": [
+                {
+                    "type": "web_search"
+                }
+            ],
+            "expected": "web_search"
+        }
+    },
+    "tool_calling_builtin_brave_search": {
+        "data": {
+            "messages": [
+                {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."},
+                {"role": "user", "content": "What's the weather like in San Francisco?"}
+            ],
+            "tools": [{
+                "tool_name": "brave_search"
+            }],
+            "expected": "web_search"
+        }
+    },
+    "tool_calling_builtin_code_interpreter_deprecated": {
+        "data": {
+            "messages": [
+                {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."},
+                {"role": "user", "content": "plot log(x) from -10 to 10"}
+            ],
+            "tools": [{
+                "tool_name": "code_interpreter"
+            }],
+            "expected": "code_interpreter"
+        }
+    },
+    "tool_calling_builtin_code_interpreter": {
+        "data": {
+            "messages": [
+                {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."},
+                {"role": "user", "content": "plot log(x) from -10 to 10"}
+            ],
+            "tools": [{
+                "type": "code_interpreter"
+            }],
+            "expected": "code_interpreter"
+        }
+    },
     "tool_calling_tools_absent": {
         "data": {
             "messages": [
@@ -146,6 +220,7 @@
             "tool_calls": [
                 {
                     "call_id": "1",
+                    "type": "function",
                     "tool_name": "get_object_namespace_list",
                     "arguments": {
                         "kind": "pod",

View file

@@ -5,7 +5,8 @@
 # the root directory of this source tree.

 import asyncio
-import unittest
+
+import pytest

 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -17,10 +18,11 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
-    ToolDefinition,
+    CodeInterpreterTool,
+    FunctionTool,
     ToolParamDefinition,
     ToolPromptFormat,
+    WebSearchTool,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
@@ -31,258 +33,269 @@ MODEL = "Llama3.1-8B-Instruct"
 MODEL3_2 = "Llama3.2-3B-Instruct"

-class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
-    async def asyncSetUp(self):
-        asyncio.get_running_loop().set_debug(False)
-
-    async def test_system_default(self):
-        content = "Hello !"
-        request = ChatCompletionRequest(
-            model=MODEL,
-            messages=[
-                UserMessage(content=content),
-            ],
-        )
-        messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 2)
-        self.assertEqual(messages[-1].content, content)
-        self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
-
-    async def test_system_builtin_only(self):
-        content = "Hello !"
-        request = ChatCompletionRequest(
-            model=MODEL,
-            messages=[
-                UserMessage(content=content),
-            ],
-            tools=[
-                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-                ToolDefinition(tool_name=BuiltinTool.brave_search),
-            ],
-        )
-        messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 2)
-        self.assertEqual(messages[-1].content, content)
-        self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
-        self.assertTrue("Tools: brave_search" in messages[0].content)
-
-    async def test_system_custom_only(self):
-        content = "Hello !"
-        request = ChatCompletionRequest(
-            model=MODEL,
-            messages=[
-                UserMessage(content=content),
-            ],
-            tools=[
-                ToolDefinition(
-                    tool_name="custom1",
-                    description="custom1 tool",
-                    parameters={
-                        "param1": ToolParamDefinition(
-                            param_type="str",
-                            description="param1 description",
-                            required=True,
-                        ),
-                    },
-                )
-            ],
-            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
-        )
-        messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 3)
-        self.assertTrue("Environment: ipython" in messages[0].content)
-        self.assertTrue("Return function calls in JSON format" in messages[1].content)
-        self.assertEqual(messages[-1].content, content)
-
-    async def test_system_custom_and_builtin(self):
-        content = "Hello !"
-        request = ChatCompletionRequest(
-            model=MODEL,
-            messages=[
-                UserMessage(content=content),
-            ],
-            tools=[
-                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-                ToolDefinition(tool_name=BuiltinTool.brave_search),
-                ToolDefinition(
-                    tool_name="custom1",
-                    description="custom1 tool",
-                    parameters={
-                        "param1": ToolParamDefinition(
-                            param_type="str",
-                            description="param1 description",
-                            required=True,
-                        ),
-                    },
-                ),
-            ],
-        )
-        messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 3)
-
-        self.assertTrue("Environment: ipython" in messages[0].content)
-        self.assertTrue("Tools: brave_search" in messages[0].content)
-
-        self.assertTrue("Return function calls in JSON format" in messages[1].content)
-        self.assertEqual(messages[-1].content, content)
-
-    async def test_completion_message_encoding(self):
-        request = ChatCompletionRequest(
-            model=MODEL3_2,
-            messages=[
-                UserMessage(content="hello"),
-                CompletionMessage(
-                    content="",
-                    stop_reason=StopReason.end_of_turn,
-                    tool_calls=[
-                        ToolCall(
-                            tool_name="custom1",
-                            arguments={"param1": "value1"},
-                            call_id="123",
-                        )
-                    ],
-                ),
-            ],
-            tools=[
-                ToolDefinition(
-                    tool_name="custom1",
-                    description="custom1 tool",
-                    parameters={
-                        "param1": ToolParamDefinition(
-                            param_type="str",
-                            description="param1 description",
-                            required=True,
-                        ),
-                    },
-                ),
-            ],
-            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
-        )
-        prompt = await chat_completion_request_to_prompt(request, request.model)
-        self.assertIn('[custom1(param1="value1")]', prompt)
-
-        request.model = MODEL
-        request.tool_config.tool_prompt_format = ToolPromptFormat.json
-        prompt = await chat_completion_request_to_prompt(request, request.model)
-        self.assertIn(
-            '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
-            prompt,
-        )
-
-    async def test_user_provided_system_message(self):
-        content = "Hello !"
-        system_prompt = "You are a pirate"
-        request = ChatCompletionRequest(
-            model=MODEL,
-            messages=[
-                SystemMessage(content=system_prompt),
-                UserMessage(content=content),
-            ],
-            tools=[
-                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-            ],
-        )
-        messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 2, messages)
-        self.assertTrue(messages[0].content.endswith(system_prompt))
-        self.assertEqual(messages[-1].content, content)
-
-    async def test_repalce_system_message_behavior_builtin_tools(self):
-        content = "Hello !"
-        system_prompt = "You are a pirate"
-        request = ChatCompletionRequest(
-            model=MODEL,
-            messages=[
-                SystemMessage(content=system_prompt),
-                UserMessage(content=content),
-            ],
-            tools=[
-                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-            ],
-            tool_config=ToolConfig(
-                tool_choice="auto",
-                tool_prompt_format="python_list",
-                system_message_behavior="replace",
-            ),
-        )
-        messages = chat_completion_request_to_messages(request, MODEL3_2)
-        self.assertEqual(len(messages), 2, messages)
-        self.assertTrue(messages[0].content.endswith(system_prompt))
-        self.assertIn("Environment: ipython", messages[0].content)
-        self.assertEqual(messages[-1].content, content)
-
-    async def test_repalce_system_message_behavior_custom_tools(self):
-        content = "Hello !"
-        system_prompt = "You are a pirate"
-        request = ChatCompletionRequest(
-            model=MODEL,
-            messages=[
-                SystemMessage(content=system_prompt),
-                UserMessage(content=content),
-            ],
-            tools=[
-                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-                ToolDefinition(
-                    tool_name="custom1",
-                    description="custom1 tool",
-                    parameters={
-                        "param1": ToolParamDefinition(
-                            param_type="str",
-                            description="param1 description",
-                            required=True,
-                        ),
-                    },
-                ),
-            ],
-            tool_config=ToolConfig(
-                tool_choice="auto",
-                tool_prompt_format="python_list",
-                system_message_behavior="replace",
-            ),
-        )
-        messages = chat_completion_request_to_messages(request, MODEL3_2)
-
-        self.assertEqual(len(messages), 2, messages)
-        self.assertTrue(messages[0].content.endswith(system_prompt))
-        self.assertIn("Environment: ipython", messages[0].content)
-        self.assertEqual(messages[-1].content, content)
-
-    async def test_replace_system_message_behavior_custom_tools_with_template(self):
-        content = "Hello !"
-        system_prompt = "You are a pirate {{ function_description }}"
-        request = ChatCompletionRequest(
-            model=MODEL,
-            messages=[
-                SystemMessage(content=system_prompt),
-                UserMessage(content=content),
-            ],
-            tools=[
-                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-                ToolDefinition(
-                    tool_name="custom1",
-                    description="custom1 tool",
-                    parameters={
-                        "param1": ToolParamDefinition(
-                            param_type="str",
-                            description="param1 description",
-                            required=True,
-                        ),
-                    },
-                ),
-            ],
-            tool_config=ToolConfig(
-                tool_choice="auto",
-                tool_prompt_format="python_list",
-                system_message_behavior="replace",
-            ),
-        )
-        messages = chat_completion_request_to_messages(request, MODEL3_2)
-
-        self.assertEqual(len(messages), 2, messages)
-        self.assertIn("Environment: ipython", messages[0].content)
-        self.assertIn("You are a pirate", messages[0].content)
-        # function description is present in the system prompt
-        self.assertIn('"name": "custom1"', messages[0].content)
-        self.assertEqual(messages[-1].content, content)
+@pytest.fixture(autouse=True)
+def setup_loop():
+    loop = asyncio.get_event_loop()
+    loop.set_debug(False)
+    return loop
+
+
+@pytest.mark.asyncio
+async def test_system_default():
+    content = "Hello !"
+    request = ChatCompletionRequest(
+        model=MODEL,
+        messages=[
+            UserMessage(content=content),
+        ],
+    )
+    messages = chat_completion_request_to_messages(request, MODEL)
+    assert len(messages) == 2
+    assert messages[-1].content == content
+    assert "Cutting Knowledge Date: December 2023" in messages[0].content
+
+
+@pytest.mark.asyncio
+async def test_system_builtin_only():
+    content = "Hello !"
+    request = ChatCompletionRequest(
+        model=MODEL,
+        messages=[
+            UserMessage(content=content),
+        ],
+        tools=[
+            CodeInterpreterTool(),
+            WebSearchTool(),
+        ],
+    )
+    messages = chat_completion_request_to_messages(request, MODEL)
+    assert len(messages) == 2
+    assert messages[-1].content == content
+    assert "Cutting Knowledge Date: December 2023" in messages[0].content
+    assert "Tools: brave_search" in messages[0].content
+
+
+@pytest.mark.asyncio
+async def test_system_custom_only():
+    content = "Hello !"
+    request = ChatCompletionRequest(
+        model=MODEL,
+        messages=[
+            UserMessage(content=content),
+        ],
+        tools=[
+            FunctionTool(
+                name="custom1",
+                description="custom1 tool",
+                parameters={
+                    "param1": ToolParamDefinition(
+                        param_type="str",
+                        description="param1 description",
+                        required=True,
+                    ),
+                },
+            )
+        ],
+        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
+    )
+    messages = chat_completion_request_to_messages(request, MODEL)
+    assert len(messages) == 3
+    assert "Environment: ipython" in messages[0].content
+    assert "Return function calls in JSON format" in messages[1].content
+    assert messages[-1].content == content
+
+
+@pytest.mark.asyncio
+async def test_system_custom_and_builtin():
+    content = "Hello !"
+    request = ChatCompletionRequest(
+        model=MODEL,
+        messages=[
+            UserMessage(content=content),
+        ],
+        tools=[
+            CodeInterpreterTool(),
+            WebSearchTool(),
+            FunctionTool(
+                name="custom1",
+                description="custom1 tool",
+                parameters={
+                    "param1": ToolParamDefinition(
+                        param_type="str",
+                        description="param1 description",
+                        required=True,
+                    ),
+                },
+            ),
+        ],
+    )
+    messages = chat_completion_request_to_messages(request, MODEL)
+    assert len(messages) == 3
+    assert "Environment: ipython" in messages[0].content
+    assert "Tools: brave_search" in messages[0].content
+    assert "Return function calls in JSON format" in messages[1].content
+    assert messages[-1].content == content
+
+
+@pytest.mark.asyncio
+async def test_completion_message_encoding():
+    request = ChatCompletionRequest(
+        model=MODEL3_2,
+        messages=[
+            UserMessage(content="hello"),
+            CompletionMessage(
+                content="",
+                stop_reason=StopReason.end_of_turn,
+                tool_calls=[
+                    ToolCall(
+                        type="function",
+                        tool_name="custom1",
+                        arguments={"param1": "value1"},
+                        call_id="123",
+                    )
+                ],
+            ),
+        ],
+        tools=[
+            FunctionTool(
+                name="custom1",
+                description="custom1 tool",
+                parameters={
+                    "param1": ToolParamDefinition(
+                        param_type="str",
+                        description="param1 description",
+                        required=True,
+                    ),
+                },
+            ),
+        ],
+        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
+    )
+    prompt = await chat_completion_request_to_prompt(request, request.model)
+    assert '[custom1(param1="value1")]' in prompt
+
+    request.model = MODEL
+    request.tool_config.tool_prompt_format = ToolPromptFormat.json
+    prompt = await chat_completion_request_to_prompt(request, request.model)
+    assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
+
+
+@pytest.mark.asyncio
+async def test_user_provided_system_message():
+    content = "Hello !"
+    system_prompt = "You are a pirate"
+    request = ChatCompletionRequest(
+        model=MODEL,
+        messages=[
+            SystemMessage(content=system_prompt),
+            UserMessage(content=content),
+        ],
+        tools=[
+            CodeInterpreterTool(),
+        ],
+    )
+    messages = chat_completion_request_to_messages(request, MODEL)
+    assert len(messages) == 2
+    assert messages[0].content.endswith(system_prompt)
+    assert messages[-1].content == content
+
+
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_builtin_tools():
+    content = "Hello !"
+    system_prompt = "You are a pirate"
+    request = ChatCompletionRequest(
+        model=MODEL,
+        messages=[
+            SystemMessage(content=system_prompt),
+            UserMessage(content=content),
+        ],
+        tools=[
+            CodeInterpreterTool(),
+        ],
+        tool_config=ToolConfig(
+            tool_choice="auto",
+            tool_prompt_format="python_list",
+            system_message_behavior="replace",
+        ),
+    )
+    messages = chat_completion_request_to_messages(request, MODEL3_2)
+    assert len(messages) == 2
+    assert messages[0].content.endswith(system_prompt)
+    assert "Environment: ipython" in messages[0].content
+    assert messages[-1].content == content
+
+
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_custom_tools():
+    content = "Hello !"
+    system_prompt = "You are a pirate"
+    request = ChatCompletionRequest(
+        model=MODEL,
+        messages=[
+            SystemMessage(content=system_prompt),
+            UserMessage(content=content),
+        ],
+        tools=[
+            CodeInterpreterTool(),
+            FunctionTool(
+                name="custom1",
+                description="custom1 tool",
+                parameters={
+                    "param1": ToolParamDefinition(
+                        param_type="str",
+                        description="param1 description",
+                        required=True,
+                    ),
+                },
+            ),
+        ],
+        tool_config=ToolConfig(
+            tool_choice="auto",
+            tool_prompt_format="python_list",
+            system_message_behavior="replace",
+        ),
+    )
+    messages = chat_completion_request_to_messages(request, MODEL3_2)
+    assert len(messages) == 2
+    assert messages[0].content.endswith(system_prompt)
+    assert "Environment: ipython" in messages[0].content
+    assert messages[-1].content == content
+
+
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_custom_tools_with_template():
+    content = "Hello !"
+    system_prompt = "You are a pirate {{ function_description }}"
+    request = ChatCompletionRequest(
+        model=MODEL,
+        messages=[
+            SystemMessage(content=system_prompt),
+            UserMessage(content=content),
+        ],
+        tools=[
+            CodeInterpreterTool(),
+            FunctionTool(
+                name="custom1",
+                description="custom1 tool",
+                parameters={
+                    "param1": ToolParamDefinition(
+                        param_type="str",
+                        description="param1 description",
+                        required=True,
+                    ),
+                },
+            ),
+        ],
+        tool_config=ToolConfig(
+            tool_choice="auto",
+            tool_prompt_format="python_list",
+            system_message_behavior="replace",
+        ),
+    )
+    messages = chat_completion_request_to_messages(request, MODEL3_2)
+    assert len(messages) == 2
+    assert "Environment: ipython" in messages[0].content
+    assert "You are a pirate" in messages[0].content
+    # function description is present in the system prompt
+    assert '"name": "custom1"' in messages[0].content
+    assert messages[-1].content == content
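
One detail worth calling out from `test_completion_message_encoding`: the same tool call renders differently depending on `ToolPromptFormat`. A toy formatter reproducing the two encodings the test asserts (the real rendering lives in llama_stack's chat-format code; this is only a sketch):

```
import json


def render_tool_call(name: str, arguments: dict, fmt: str) -> str:
    if fmt == "python_list":
        # Llama 3.2 python_list style: [name(arg1=value1, ...)]
        args = ", ".join(f"{k}={json.dumps(v)}" for k, v in arguments.items())
        return f"[{name}({args})]"
    if fmt == "json":
        # Llama 3.1 JSON style: a single function-call object.
        return json.dumps({"type": "function", "name": name, "parameters": arguments})
    raise ValueError(f"Unsupported ToolPromptFormat {fmt}")


assert render_tool_call("custom1", {"param1": "value1"}, "python_list") == '[custom1(param1="value1")]'
assert render_tool_call("custom1", {"param1": "value1"}, "json") == (
    '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}'
)
```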

View file

@@ -33,7 +33,7 @@ class PromptTemplateTests(unittest.TestCase):
             if not example:
                 continue
             for tool in example:
-                assert tool.tool_name in text
+                assert tool.name in text

     def test_system_default(self):
         generator = SystemDefaultGenerator()