diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 8a46a89ad..2274f86e7 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -3798,6 +3798,21 @@
],
"title": "AppendRowsRequest"
},
+ "CodeInterpreterTool": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "code_interpreter",
+ "default": "code_interpreter"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type"
+ ],
+ "title": "CodeInterpreterTool"
+ },
"CompletionMessage": {
"type": "object",
"properties": {
@@ -3837,6 +3852,34 @@
"title": "CompletionMessage",
"description": "A message containing the model's (assistant) response in a chat conversation."
},
+ "FunctionTool": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "function",
+ "default": "function"
+ },
+ "name": {
+ "type": "string"
+ },
+ "description": {
+ "type": "string"
+ },
+ "parameters": {
+ "type": "object",
+ "additionalProperties": {
+ "$ref": "#/components/schemas/ToolParamDefinition"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "name"
+ ],
+ "title": "FunctionTool"
+ },
"GrammarResponseFormat": {
"type": "object",
"properties": {
@@ -4138,25 +4181,21 @@
"ToolCall": {
"type": "object",
"properties": {
+ "type": {
+ "type": "string",
+ "enum": [
+ "function",
+ "web_search",
+ "wolfram_alpha",
+ "code_interpreter"
+ ],
+ "title": "ToolType"
+ },
"call_id": {
"type": "string"
},
"tool_name": {
- "oneOf": [
- {
- "type": "string",
- "enum": [
- "brave_search",
- "wolfram_alpha",
- "photogen",
- "code_interpreter"
- ],
- "title": "BuiltinTool"
- },
- {
- "type": "string"
- }
- ]
+ "type": "string"
},
"arguments": {
"oneOf": [
@@ -4237,48 +4276,13 @@
},
"additionalProperties": false,
"required": [
+ "type",
"call_id",
"tool_name",
"arguments"
],
"title": "ToolCall"
},
- "ToolDefinition": {
- "type": "object",
- "properties": {
- "tool_name": {
- "oneOf": [
- {
- "type": "string",
- "enum": [
- "brave_search",
- "wolfram_alpha",
- "photogen",
- "code_interpreter"
- ],
- "title": "BuiltinTool"
- },
- {
- "type": "string"
- }
- ]
- },
- "description": {
- "type": "string"
- },
- "parameters": {
- "type": "object",
- "additionalProperties": {
- "$ref": "#/components/schemas/ToolParamDefinition"
- }
- }
- },
- "additionalProperties": false,
- "required": [
- "tool_name"
- ],
- "title": "ToolDefinition"
- },
"ToolParamDefinition": {
"type": "object",
"properties": {
@@ -4428,6 +4432,36 @@
"title": "UserMessage",
"description": "A message from the user in a chat conversation."
},
+ "WebSearchTool": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "web_search",
+ "default": "web_search"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type"
+ ],
+ "title": "WebSearchTool"
+ },
+ "WolframAlphaTool": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "wolfram_alpha",
+ "default": "wolfram_alpha"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type"
+ ],
+ "title": "WolframAlphaTool"
+ },
"BatchChatCompletionRequest": {
"type": "object",
"properties": {
@@ -4449,7 +4483,29 @@
"tools": {
"type": "array",
"items": {
- "$ref": "#/components/schemas/ToolDefinition"
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/WebSearchTool"
+ },
+ {
+ "$ref": "#/components/schemas/WolframAlphaTool"
+ },
+ {
+ "$ref": "#/components/schemas/CodeInterpreterTool"
+ },
+ {
+ "$ref": "#/components/schemas/FunctionTool"
+ }
+ ],
+ "discriminator": {
+ "propertyName": "type",
+ "mapping": {
+ "web_search": "#/components/schemas/WebSearchTool",
+ "wolfram_alpha": "#/components/schemas/WolframAlphaTool",
+ "code_interpreter": "#/components/schemas/CodeInterpreterTool",
+ "function": "#/components/schemas/FunctionTool"
+ }
+ }
}
},
"tool_choice": {
@@ -4734,6 +4790,41 @@
"title": "ToolConfig",
"description": "Configuration for tool use."
},
+ "ToolDefinitionDeprecated": {
+ "type": "object",
+ "properties": {
+ "tool_name": {
+ "oneOf": [
+ {
+ "type": "string",
+ "enum": [
+ "brave_search",
+ "wolfram_alpha",
+ "code_interpreter"
+ ],
+ "title": "BuiltinTool"
+ },
+ {
+ "type": "string"
+ }
+ ]
+ },
+ "description": {
+ "type": "string"
+ },
+ "parameters": {
+ "type": "object",
+ "additionalProperties": {
+ "$ref": "#/components/schemas/ToolParamDefinition"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "tool_name"
+ ],
+ "title": "ToolDefinitionDeprecated"
+ },
"ChatCompletionRequest": {
"type": "object",
"properties": {
@@ -4753,10 +4844,42 @@
"description": "Parameters to control the sampling strategy"
},
"tools": {
- "type": "array",
- "items": {
- "$ref": "#/components/schemas/ToolDefinition"
- },
+ "oneOf": [
+ {
+ "type": "array",
+ "items": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/WebSearchTool"
+ },
+ {
+ "$ref": "#/components/schemas/WolframAlphaTool"
+ },
+ {
+ "$ref": "#/components/schemas/CodeInterpreterTool"
+ },
+ {
+ "$ref": "#/components/schemas/FunctionTool"
+ }
+ ],
+ "discriminator": {
+ "propertyName": "type",
+ "mapping": {
+ "web_search": "#/components/schemas/WebSearchTool",
+ "wolfram_alpha": "#/components/schemas/WolframAlphaTool",
+ "code_interpreter": "#/components/schemas/CodeInterpreterTool",
+ "function": "#/components/schemas/FunctionTool"
+ }
+ }
+ }
+ },
+ {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ToolDefinitionDeprecated"
+ }
+ }
+ ],
"description": "(Optional) List of tool definitions available to the model"
},
"tool_choice": {
@@ -5630,21 +5753,7 @@
"type": "string"
},
"tool_name": {
- "oneOf": [
- {
- "type": "string",
- "enum": [
- "brave_search",
- "wolfram_alpha",
- "photogen",
- "code_interpreter"
- ],
- "title": "BuiltinTool"
- },
- {
- "type": "string"
- }
- ]
+ "type": "string"
},
"content": {
"$ref": "#/components/schemas/InterleavedContent"
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 0b8f90490..d1517bd1e 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2607,6 +2607,17 @@ components:
required:
- rows
title: AppendRowsRequest
+ CodeInterpreterTool:
+ type: object
+ properties:
+ type:
+ type: string
+ const: code_interpreter
+ default: code_interpreter
+ additionalProperties: false
+ required:
+ - type
+ title: CodeInterpreterTool
CompletionMessage:
type: object
properties:
@@ -2646,6 +2657,26 @@ components:
title: CompletionMessage
description: >-
A message containing the model's (assistant) response in a chat conversation.
+ FunctionTool:
+ type: object
+ properties:
+ type:
+ type: string
+ const: function
+ default: function
+ name:
+ type: string
+ description:
+ type: string
+ parameters:
+ type: object
+ additionalProperties:
+ $ref: '#/components/schemas/ToolParamDefinition'
+ additionalProperties: false
+ required:
+ - type
+ - name
+ title: FunctionTool
GrammarResponseFormat:
type: object
properties:
@@ -2851,18 +2882,18 @@ components:
ToolCall:
type: object
properties:
+ type:
+ type: string
+ enum:
+ - function
+ - web_search
+ - wolfram_alpha
+ - code_interpreter
+ title: ToolType
call_id:
type: string
tool_name:
- oneOf:
- - type: string
- enum:
- - brave_search
- - wolfram_alpha
- - photogen
- - code_interpreter
- title: BuiltinTool
- - type: string
+ type: string
arguments:
oneOf:
- type: string
@@ -2894,33 +2925,11 @@ components:
type: string
additionalProperties: false
required:
+ - type
- call_id
- tool_name
- arguments
title: ToolCall
- ToolDefinition:
- type: object
- properties:
- tool_name:
- oneOf:
- - type: string
- enum:
- - brave_search
- - wolfram_alpha
- - photogen
- - code_interpreter
- title: BuiltinTool
- - type: string
- description:
- type: string
- parameters:
- type: object
- additionalProperties:
- $ref: '#/components/schemas/ToolParamDefinition'
- additionalProperties: false
- required:
- - tool_name
- title: ToolDefinition
ToolParamDefinition:
type: object
properties:
@@ -3031,6 +3040,28 @@ components:
title: UserMessage
description: >-
A message from the user in a chat conversation.
+ WebSearchTool:
+ type: object
+ properties:
+ type:
+ type: string
+ const: web_search
+ default: web_search
+ additionalProperties: false
+ required:
+ - type
+ title: WebSearchTool
+ WolframAlphaTool:
+ type: object
+ properties:
+ type:
+ type: string
+ const: wolfram_alpha
+ default: wolfram_alpha
+ additionalProperties: false
+ required:
+ - type
+ title: WolframAlphaTool
BatchChatCompletionRequest:
type: object
properties:
@@ -3047,7 +3078,18 @@ components:
tools:
type: array
items:
- $ref: '#/components/schemas/ToolDefinition'
+ oneOf:
+ - $ref: '#/components/schemas/WebSearchTool'
+ - $ref: '#/components/schemas/WolframAlphaTool'
+ - $ref: '#/components/schemas/CodeInterpreterTool'
+ - $ref: '#/components/schemas/FunctionTool'
+ discriminator:
+ propertyName: type
+ mapping:
+ web_search: '#/components/schemas/WebSearchTool'
+ wolfram_alpha: '#/components/schemas/WolframAlphaTool'
+ code_interpreter: '#/components/schemas/CodeInterpreterTool'
+ function: '#/components/schemas/FunctionTool'
tool_choice:
type: string
enum:
@@ -3272,6 +3314,28 @@ components:
additionalProperties: false
title: ToolConfig
description: Configuration for tool use.
+ ToolDefinitionDeprecated:
+ type: object
+ properties:
+ tool_name:
+ oneOf:
+ - type: string
+ enum:
+ - brave_search
+ - wolfram_alpha
+ - code_interpreter
+ title: BuiltinTool
+ - type: string
+ description:
+ type: string
+ parameters:
+ type: object
+ additionalProperties:
+ $ref: '#/components/schemas/ToolParamDefinition'
+ additionalProperties: false
+ required:
+ - tool_name
+ title: ToolDefinitionDeprecated
ChatCompletionRequest:
type: object
properties:
@@ -3290,9 +3354,24 @@ components:
description: >-
Parameters to control the sampling strategy
tools:
- type: array
- items:
- $ref: '#/components/schemas/ToolDefinition'
+ oneOf:
+ - type: array
+ items:
+ oneOf:
+ - $ref: '#/components/schemas/WebSearchTool'
+ - $ref: '#/components/schemas/WolframAlphaTool'
+ - $ref: '#/components/schemas/CodeInterpreterTool'
+ - $ref: '#/components/schemas/FunctionTool'
+ discriminator:
+ propertyName: type
+ mapping:
+ web_search: '#/components/schemas/WebSearchTool'
+ wolfram_alpha: '#/components/schemas/WolframAlphaTool'
+ code_interpreter: '#/components/schemas/CodeInterpreterTool'
+ function: '#/components/schemas/FunctionTool'
+ - type: array
+ items:
+ $ref: '#/components/schemas/ToolDefinitionDeprecated'
description: >-
(Optional) List of tool definitions available to the model
tool_choice:
@@ -3939,15 +4018,7 @@ components:
call_id:
type: string
tool_name:
- oneOf:
- - type: string
- enum:
- - brave_search
- - wolfram_alpha
- - photogen
- - code_interpreter
- title: BuiltinTool
- - type: string
+ type: string
content:
$ref: '#/components/schemas/InterleavedContent'
metadata:
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 7d3539dcb..c51f4b971 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -17,18 +17,18 @@ from typing import (
runtime_checkable,
)
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
- BuiltinTool,
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
+ ToolDefinitionDeprecated,
ToolPromptFormat,
)
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -156,23 +156,14 @@ Message = Annotated[
register_schema(Message, name="Message")
+# TODO: move this to agent.py where this is used
@json_schema_type
class ToolResponse(BaseModel):
call_id: str
- tool_name: Union[BuiltinTool, str]
+ tool_name: str
content: InterleavedContent
metadata: Optional[Dict[str, Any]] = None
- @field_validator("tool_name", mode="before")
- @classmethod
- def validate_field(cls, v):
- if isinstance(v, str):
- try:
- return BuiltinTool(v)
- except ValueError:
- return v
- return v
-
class ToolChoice(Enum):
"""Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.
@@ -462,7 +453,8 @@ class Inference(Protocol):
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
- tools: Optional[List[ToolDefinition]] = None,
+ # TODO: remove ToolDefinitionDeprecated in v0.1.10
+ tools: Optional[List[ToolDefinition] | List[ToolDefinitionDeprecated]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 6ff36a65c..764fa2406 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
ToolChoice,
ToolConfig,
ToolDefinition,
+ ToolDefinitionDeprecated,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
@@ -54,6 +55,9 @@ from llama_stack.apis.tools import (
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
+from llama_stack.models.llama.datatypes import (
+ ToolType,
+)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import RoutingTable
@@ -229,7 +233,7 @@ class InferenceRouter(Inference):
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
- tools: Optional[List[ToolDefinition]] = None,
+ tools: Optional[List[ToolDefinition] | List[ToolDefinitionDeprecated]] = None,
tool_choice: Optional[ToolChoice] = None,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
@@ -259,24 +263,42 @@ class InferenceRouter(Inference):
params["tool_prompt_format"] = tool_prompt_format
tool_config = ToolConfig(**params)
- tools = tools or []
+ # TODO: remove ToolDefinitionDeprecated in v0.1.10
+ converted_tools = []
+ for tool in tools or []:
+ if isinstance(tool, ToolDefinitionDeprecated):
+ logger.warning(f"ToolDefinitionDeprecated: {tool}, use ToolDefinition instead")
+ converted_tools.append(tool.to_tool_definition())
+ else:
+ converted_tools.append(tool)
+
if tool_config.tool_choice == ToolChoice.none:
- tools = []
+ converted_tools = []
elif tool_config.tool_choice == ToolChoice.auto:
pass
elif tool_config.tool_choice == ToolChoice.required:
pass
else:
# verify tool_choice is one of the tools
- tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
- if tool_config.tool_choice not in tool_names:
- raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
+ for t in converted_tools:
+ if t.type == ToolType.function.value:
+ if tool_config.tool_choice == t.name:
+ break
+ elif t.type in (
+ ToolType.web_search.value,
+ ToolType.wolfram_alpha.value,
+ ToolType.code_interpreter.value,
+ ):
+ if tool_config.tool_choice == t.type:
+ break
+ else:
+ raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {converted_tools}")
params = dict(
model_id=model_id,
messages=messages,
sampling_params=sampling_params,
- tools=tools,
+ tools=converted_tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py
index f762eb50f..f1c280f32 100644
--- a/llama_stack/models/llama/datatypes.py
+++ b/llama_stack/models/llama/datatypes.py
@@ -33,10 +33,10 @@ class Role(Enum):
tool = "tool"
-class BuiltinTool(Enum):
- brave_search = "brave_search"
+class ToolType(Enum):
+ function = "function"
+ web_search = "web_search"
wolfram_alpha = "wolfram_alpha"
- photogen = "photogen"
code_interpreter = "code_interpreter"
@@ -45,8 +45,9 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
class ToolCall(BaseModel):
+ type: ToolType
call_id: str
- tool_name: Union[BuiltinTool, str]
+ tool_name: str
# Plan is to deprecate the Dict in favor of a JSON string
# that is parsed on the client side instead of trying to manage
# the recursive type here.
@@ -59,12 +60,18 @@ class ToolCall(BaseModel):
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
+ # for backwards compatibility, we allow the tool name to be a string or a BuiltinTool
+ # TODO: remove ToolDefinitionDeprecated in v0.1.10
+ tool_name = v
if isinstance(v, str):
try:
- return BuiltinTool(v)
+ tool_name = BuiltinTool(v)
except ValueError:
- return v
- return v
+ pass
+
+ if isinstance(tool_name, BuiltinTool):
+ return tool_name.to_tool().type
+ return tool_name
class ToolPromptFormat(Enum):
@@ -151,8 +158,136 @@ class ToolParamDefinition(BaseModel):
default: Optional[Any] = None
+class Tool(BaseModel):
+ type: ToolType
+
+ @classmethod
+ def __init_subclass__(cls, **kwargs):
+ super().__init_subclass__(**kwargs)
+
+ required_properties = ["name", "description", "parameters"]
+ for prop in required_properties:
+            has_property = isinstance(cls.__dict__.get(prop), property)
+ has_field = prop in cls.__annotations__ or prop in cls.__dict__
+ if not has_property and not has_field:
+ raise TypeError(f"Class {cls.__name__} must implement '{prop}' property or field")
+
+
@json_schema_type
-class ToolDefinition(BaseModel):
+class WebSearchTool(Tool):
+ type: Literal[ToolType.web_search.value] = ToolType.web_search.value
+
+ @property
+ def name(self) -> str:
+ return "web_search"
+
+ @property
+ def description(self) -> str:
+ return "Search the web for information"
+
+ @property
+ def parameters(self) -> Dict[str, ToolParamDefinition]:
+ return {
+ "query": ToolParamDefinition(
+ description="The query to search for",
+ param_type="string",
+ required=True,
+ ),
+ }
+
+
+@json_schema_type
+class WolframAlphaTool(Tool):
+ type: Literal[ToolType.wolfram_alpha.value] = ToolType.wolfram_alpha.value
+
+ @property
+ def name(self) -> str:
+ return "wolfram_alpha"
+
+ @property
+ def description(self) -> str:
+ return "Query WolframAlpha for computational knowledge"
+
+ @property
+ def parameters(self) -> Dict[str, ToolParamDefinition]:
+ return {
+ "query": ToolParamDefinition(
+ description="The query to compute",
+ param_type="string",
+ required=True,
+ ),
+ }
+
+
+@json_schema_type
+class CodeInterpreterTool(Tool):
+ type: Literal[ToolType.code_interpreter.value] = ToolType.code_interpreter.value
+
+ @property
+ def name(self) -> str:
+ return "code_interpreter"
+
+ @property
+ def description(self) -> str:
+ return "Execute code"
+
+ @property
+ def parameters(self) -> Dict[str, ToolParamDefinition]:
+ return {
+ "code": ToolParamDefinition(
+ description="The code to execute",
+ param_type="string",
+ required=True,
+ ),
+ }
+
+
+@json_schema_type
+class FunctionTool(Tool):
+ type: Literal[ToolType.function.value] = ToolType.function.value
+ name: str
+ description: Optional[str] = None
+ parameters: Optional[Dict[str, ToolParamDefinition]] = None
+
+ @field_validator("name", mode="before")
+ @classmethod
+ def validate_name(cls, v):
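+        # reject names like "web_search" that would shadow a builtin tool type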
+ if v in ToolType.__members__:
+ raise ValueError(f"Tool name '{v}' is a tool type and cannot be used as a name of a function tool")
+ return v
+
+
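+# Discriminated union: pydantic resolves the concrete tool class from the
+# `type` field, mirroring the discriminator blocks in the OpenAPI spec above.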
+ToolDefinition = Annotated[
+ Union[WebSearchTool, WolframAlphaTool, CodeInterpreterTool, FunctionTool], Field(discriminator="type")
+]
+
+
+# TODO: remove ToolDefinitionDeprecated in v0.1.10
+class BuiltinTool(Enum):
+ brave_search = "brave_search"
+ wolfram_alpha = "wolfram_alpha"
+ code_interpreter = "code_interpreter"
+
+ def to_tool_type(self) -> ToolType:
+ if self == BuiltinTool.brave_search:
+ return ToolType.web_search
+ elif self == BuiltinTool.wolfram_alpha:
+ return ToolType.wolfram_alpha
+ elif self == BuiltinTool.code_interpreter:
+ return ToolType.code_interpreter
+
+ def to_tool(self) -> WebSearchTool | WolframAlphaTool | CodeInterpreterTool:
+ if self == BuiltinTool.brave_search:
+ return WebSearchTool()
+ elif self == BuiltinTool.wolfram_alpha:
+ return WolframAlphaTool()
+ elif self == BuiltinTool.code_interpreter:
+ return CodeInterpreterTool()
+
+
+# TODO: remove ToolDefinitionDeprecated in v0.1.10
+@json_schema_type
+class ToolDefinitionDeprecated(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@@ -167,6 +302,21 @@ class ToolDefinition(BaseModel):
return v
return v
+ def to_tool_definition(self) -> ToolDefinition:
+        # map legacy builtin names to their typed tool classes; anything else becomes a FunctionTool
+ if self.tool_name == BuiltinTool.brave_search:
+ return WebSearchTool()
+ elif self.tool_name == BuiltinTool.code_interpreter:
+ return CodeInterpreterTool()
+ elif self.tool_name == BuiltinTool.wolfram_alpha:
+ return WolframAlphaTool()
+ else:
+ return FunctionTool(
+ name=self.tool_name,
+ description=self.description,
+ parameters=self.parameters,
+ )
+
@json_schema_type
class GreedySamplingStrategy(BaseModel):
diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py
index 2862f8558..e1c44cff0 100644
--- a/llama_stack/models/llama/llama3/chat_format.py
+++ b/llama_stack/models/llama/llama3/chat_format.py
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Tuple
from PIL import Image as PIL_Image
from llama_stack.models.llama.datatypes import (
- BuiltinTool,
RawContent,
RawMediaItem,
RawMessage,
@@ -29,6 +28,7 @@ from llama_stack.models.llama.datatypes import (
StopReason,
ToolCall,
ToolPromptFormat,
+ ToolType,
)
from .tokenizer import Tokenizer
@@ -127,7 +127,7 @@ class ChatFormat:
if (
message.role == "assistant"
and len(message.tool_calls) > 0
- and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter
+ and message.tool_calls[0].type == ToolType.code_interpreter
):
tokens.append(self.tokenizer.special_tokens["<|python_tag|>"])
@@ -194,6 +194,7 @@ class ChatFormat:
stop_reason = StopReason.end_of_message
tool_name = None
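+        # default to a custom function call; overridden below when the parsed
+        # call turns out to be a builtin tool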
+ tool_type = ToolType.function
tool_arguments = {}
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
@@ -202,8 +203,8 @@ class ChatFormat:
                # Sometimes when agent has custom tools alongside builtin tools
# Agent responds for builtin tool calls in the format of the custom tools
# This code tries to handle that case
- if tool_name in BuiltinTool.__members__:
- tool_name = BuiltinTool[tool_name]
+ if tool_name in ToolType.__members__:
+ tool_type = ToolType[tool_name]
if isinstance(tool_arguments, dict):
tool_arguments = {
"query": list(tool_arguments.values())[0],
@@ -215,10 +216,11 @@ class ChatFormat:
tool_arguments = {
"query": query,
}
- if tool_name in BuiltinTool.__members__:
- tool_name = BuiltinTool[tool_name]
+ if tool_name in ToolType.__members__:
+ tool_type = ToolType[tool_name]
elif ipython:
- tool_name = BuiltinTool.code_interpreter
+ tool_name = ToolType.code_interpreter.value
+ tool_type = ToolType.code_interpreter
tool_arguments = {
"code": content,
}
@@ -228,6 +230,7 @@ class ChatFormat:
call_id = str(uuid.uuid4())
tool_calls.append(
ToolCall(
+ type=tool_type,
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
diff --git a/llama_stack/models/llama/llama3/interface.py b/llama_stack/models/llama/llama3/interface.py
index 2579ab6c8..5abace9e2 100644
--- a/llama_stack/models/llama/llama3/interface.py
+++ b/llama_stack/models/llama/llama3/interface.py
@@ -17,7 +17,7 @@ from typing import List, Optional
from termcolor import colored
from llama_stack.models.llama.datatypes import (
- BuiltinTool,
+ FunctionTool,
RawMessage,
StopReason,
ToolCall,
@@ -25,7 +25,6 @@ from llama_stack.models.llama.datatypes import (
ToolPromptFormat,
)
-from . import template_data
from .chat_format import ChatFormat
from .prompt_templates import (
BuiltinToolGenerator,
@@ -150,8 +149,8 @@ class LLama31Interface:
def system_messages(
self,
- builtin_tools: List[BuiltinTool],
- custom_tools: List[ToolDefinition],
+ builtin_tools: List[ToolDefinition],
+ custom_tools: List[FunctionTool],
instruction: Optional[str] = None,
) -> List[RawMessage]:
messages = []
@@ -227,31 +226,3 @@ class LLama31Interface:
on_col = on_colors[i % len(on_colors)]
print(colored(self.tokenizer.decode([t]), "white", on_col), end="")
print("\n", end="")
-
-
-def list_jinja_templates() -> List[Template]:
- return TEMPLATES
-
-
-def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat):
- by_name = {t.template_name: t for t in TEMPLATES}
- if name not in by_name:
- raise ValueError(f"No template found for `{name}`")
-
- template = by_name[name]
- interface = LLama31Interface(tool_prompt_format)
-
- data_func = getattr(template_data, template.data_provider)
- if template.role == "system":
- messages = interface.system_messages(**data_func())
- elif template.role == "tool":
- messages = interface.tool_response_messages(**data_func())
- elif template.role == "assistant":
- messages = interface.assistant_response_messages(**data_func())
- elif template.role == "user":
- messages = interface.user_message(**data_func())
-
- tokens = interface.get_tokens(messages)
- special_tokens = list(interface.tokenizer.special_tokens.values())
- tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
- return template, tokens
diff --git a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py
index 9da6a640e..2cd69f5d0 100644
--- a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py
+++ b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py
@@ -16,9 +16,13 @@ from datetime import datetime
from typing import Any, List, Optional
from llama_stack.models.llama.datatypes import (
- BuiltinTool,
+ CodeInterpreterTool,
+ FunctionTool,
ToolDefinition,
ToolParamDefinition,
+ ToolType,
+ WebSearchTool,
+ WolframAlphaTool,
)
from .base import PromptTemplate, PromptTemplateGeneratorBase
@@ -47,7 +51,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
def _tool_breakdown(self, tools: List[ToolDefinition]):
builtin_tools, custom_tools = [], []
for dfn in tools:
- if isinstance(dfn.tool_name, BuiltinTool):
+ if dfn.type != ToolType.function.value:
builtin_tools.append(dfn)
else:
custom_tools.append(dfn)
@@ -70,7 +74,11 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
return PromptTemplate(
template_str.lstrip("\n"),
{
- "builtin_tools": [t.tool_name.value for t in builtin_tools],
+ "builtin_tools": [
+ # brave_search is used in training data for web_search
+ t.type if t.type != ToolType.web_search.value else "brave_search"
+ for t in builtin_tools
+ ],
"custom_tools": custom_tools,
},
)
@@ -79,19 +87,19 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
return [
# builtin tools
[
- ToolDefinition(tool_name=BuiltinTool.code_interpreter),
- ToolDefinition(tool_name=BuiltinTool.brave_search),
- ToolDefinition(tool_name=BuiltinTool.wolfram_alpha),
+ CodeInterpreterTool(),
+ WebSearchTool(),
+ WolframAlphaTool(),
],
            # only code interpreter
[
- ToolDefinition(tool_name=BuiltinTool.code_interpreter),
+ CodeInterpreterTool(),
],
]
class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
- def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+ def gen(self, custom_tools: List[FunctionTool]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
Answer the user's question by making use of the following functions if needed.
@@ -99,7 +107,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
Here is a list of functions in JSON format:
{% for t in custom_tools -%}
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
- {%- set tname = t.tool_name -%}
+ {%- set tname = t.name -%}
{%- set tdesc = t.description -%}
{%- set tparams = t.parameters -%}
{%- set required_params = [] -%}
@@ -140,8 +148,8 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
def data_examples(self) -> List[List[ToolDefinition]]:
return [
[
- ToolDefinition(
- tool_name="trending_songs",
+ FunctionTool(
+ name="trending_songs",
description="Returns the trending songs on a Music site",
parameters={
"n": ToolParamDefinition(
@@ -161,14 +169,14 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
- def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+ def gen(self, custom_tools: List[FunctionTool]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
You have access to the following functions:
{% for t in custom_tools %}
{#- manually setting up JSON because jinja sorts keys in unexpected ways -#}
- {%- set tname = t.tool_name -%}
+ {%- set tname = t.name -%}
{%- set tdesc = t.description -%}
{%- set modified_params = t.parameters.copy() -%}
{%- for key, value in modified_params.items() -%}
@@ -202,8 +210,8 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
def data_examples(self) -> List[List[ToolDefinition]]:
return [
[
- ToolDefinition(
- tool_name="trending_songs",
+ FunctionTool(
+ name="trending_songs",
description="Returns the trending songs on a Music site",
parameters={
"n": ToolParamDefinition(
@@ -240,7 +248,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{"function_description": self._gen_function_description(custom_tools)},
)
- def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+ def _gen_function_description(self, custom_tools: List[FunctionTool]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
@@ -252,7 +260,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
[
{% for t in tools -%}
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
- {%- set tname = t.tool_name -%}
+ {%- set tname = t.name -%}
{%- set tdesc = t.description -%}
{%- set tparams = t.parameters -%}
{%- set required_params = [] -%}
@@ -289,8 +297,8 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
def data_examples(self) -> List[List[ToolDefinition]]:
return [
[
- ToolDefinition(
- tool_name="get_weather",
+ FunctionTool(
+ name="get_weather",
description="Get weather info for places",
parameters={
"city": ToolParamDefinition(
diff --git a/llama_stack/models/llama/llama3/tool_utils.py b/llama_stack/models/llama/llama3/tool_utils.py
index 71018898c..cdbf8de1e 100644
--- a/llama_stack/models/llama/llama3/tool_utils.py
+++ b/llama_stack/models/llama/llama3/tool_utils.py
@@ -16,7 +16,7 @@ import re
from typing import Optional, Tuple
from llama_stack.log import get_logger
-from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
+from llama_stack.models.llama.datatypes import RecursiveType, ToolCall, ToolPromptFormat, ToolType
logger = get_logger(name=__name__, category="inference")
@@ -24,6 +24,12 @@ BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
+# The model was trained with "brave_search" as the name for web_search, so map it here
+TOOL_NAME_MAP = {
+ "brave_search": ToolType.web_search.value,
+}
+
+
def is_json(s):
try:
parsed = json.loads(s)
@@ -111,11 +117,6 @@ def parse_python_list_for_function_calls(input_string):
class ToolUtils:
- @staticmethod
- def is_builtin_tool_call(message_body: str) -> bool:
- match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body)
- return match is not None
-
@staticmethod
def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
# Find the first match in the text
@@ -125,7 +126,7 @@ class ToolUtils:
if match:
tool_name = match.group("tool_name")
query = match.group("query")
- return tool_name, query
+ return TOOL_NAME_MAP.get(tool_name, tool_name), query
else:
return None
@@ -143,7 +144,7 @@ class ToolUtils:
tool_name = match.group("function_name")
query = match.group("args")
try:
- return tool_name, json.loads(query.replace("'", '"'))
+ return TOOL_NAME_MAP.get(tool_name, tool_name), json.loads(query.replace("'", '"'))
except Exception as e:
print("Exception while parsing json query for custom tool call", query, e)
return None
@@ -152,30 +153,28 @@ class ToolUtils:
if ("type" in response and response["type"] == "function") or ("name" in response):
function_name = response["name"]
args = response["parameters"]
- return function_name, args
+ return TOOL_NAME_MAP.get(function_name, function_name), args
else:
return None
elif is_valid_python_list(message_body):
res = parse_python_list_for_function_calls(message_body)
# FIXME: Enable multiple tool calls
- return res[0]
+ function_name, args = res[0]
+ return TOOL_NAME_MAP.get(function_name, function_name), args
else:
return None
@staticmethod
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
- if t.tool_name == BuiltinTool.brave_search:
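+        # web_search is encoded back to brave_search, the tool name the model
+        # saw during training (inverse of TOOL_NAME_MAP above)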
+ if t.type == ToolType.web_search:
q = t.arguments["query"]
return f'brave_search.call(query="{q}")'
- elif t.tool_name == BuiltinTool.wolfram_alpha:
+ elif t.type == ToolType.wolfram_alpha:
q = t.arguments["query"]
return f'wolfram_alpha.call(query="{q}")'
- elif t.tool_name == BuiltinTool.photogen:
- q = t.arguments["query"]
- return f'photogen.call(query="{q}")'
- elif t.tool_name == BuiltinTool.code_interpreter:
+ elif t.type == ToolType.code_interpreter:
return t.arguments["code"]
- else:
+ elif t.type == ToolType.function:
fname = t.tool_name
if tool_prompt_format == ToolPromptFormat.json:
@@ -208,3 +207,5 @@ class ToolUtils:
return f"[{fname}({args_str})]"
else:
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
+ else:
+ raise ValueError(f"Unsupported tool type: {t.type}")
diff --git a/llama_stack/models/llama/llama3_1/prompts.py b/llama_stack/models/llama/llama3_1/prompts.py
index 9f56bc23b..909aacbfb 100644
--- a/llama_stack/models/llama/llama3_1/prompts.py
+++ b/llama_stack/models/llama/llama3_1/prompts.py
@@ -15,11 +15,11 @@ import textwrap
from typing import List
from llama_stack.models.llama.datatypes import (
- BuiltinTool,
RawMessage,
StopReason,
ToolCall,
ToolPromptFormat,
+ ToolType,
)
from ..prompt_format import (
@@ -184,8 +184,9 @@ def usecases() -> List[UseCase | str]:
stop_reason=StopReason.end_of_message,
tool_calls=[
ToolCall(
+ type=ToolType.wolfram_alpha,
call_id="tool_call_id",
- tool_name=BuiltinTool.wolfram_alpha,
+ tool_name=ToolType.wolfram_alpha.value,
arguments={"query": "100th decimal of pi"},
)
],
diff --git a/llama_stack/models/llama/llama3_3/prompts.py b/llama_stack/models/llama/llama3_3/prompts.py
index 194e4fa26..956f42d69 100644
--- a/llama_stack/models/llama/llama3_3/prompts.py
+++ b/llama_stack/models/llama/llama3_3/prompts.py
@@ -15,11 +15,11 @@ import textwrap
from typing import List
from llama_stack.models.llama.datatypes import (
- BuiltinTool,
RawMessage,
StopReason,
ToolCall,
ToolPromptFormat,
+ ToolType,
)
from ..prompt_format import (
@@ -183,8 +183,9 @@ def usecases() -> List[UseCase | str]:
stop_reason=StopReason.end_of_message,
tool_calls=[
ToolCall(
+ type=ToolType.wolfram_alpha,
call_id="tool_call_id",
- tool_name=BuiltinTool.wolfram_alpha,
+ tool_name=ToolType.wolfram_alpha.value,
arguments={"query": "100th decimal of pi"},
)
],
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index fe1726b07..8916bae96 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -53,7 +53,7 @@ from llama_stack.apis.inference import (
SamplingParams,
StopReason,
SystemMessage,
- ToolDefinition,
+ ToolDefinitionDeprecated,
ToolResponse,
ToolResponseMessage,
UserMessage,
@@ -771,7 +771,7 @@ class ChatAgent(ShieldRunnerMixin):
for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
- tool_name_to_def[tool_def.name] = ToolDefinition(
+ tool_name_to_def[tool_def.name] = ToolDefinitionDeprecated(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
@@ -814,7 +814,7 @@ class ChatAgent(ShieldRunnerMixin):
if tool_name_to_def.get(identifier, None):
raise ValueError(f"Tool {identifier} already exists")
if identifier:
- tool_name_to_def[tool_def.identifier] = ToolDefinition(
+ tool_name_to_def[tool_def.identifier] = ToolDefinitionDeprecated(
tool_name=identifier,
description=tool_def.description,
parameters={
@@ -854,30 +854,23 @@ class ChatAgent(ShieldRunnerMixin):
tool_call: ToolCall,
) -> ToolInvocationResult:
tool_name = tool_call.tool_name
- registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs]
+ registered_tool_names = list(self.tool_name_to_args.keys())
if tool_name not in registered_tool_names:
raise ValueError(
f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}"
)
- if isinstance(tool_name, BuiltinTool):
- if tool_name == BuiltinTool.brave_search:
- tool_name_str = WEB_SEARCH_TOOL
- else:
- tool_name_str = tool_name.value
- else:
- tool_name_str = tool_name
- logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
+ logger.info(f"executing tool call: {tool_name} with args: {tool_call.arguments}")
result = await self.tool_runtime_api.invoke_tool(
- tool_name=tool_name_str,
+ tool_name=tool_name,
kwargs={
"session_id": session_id,
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
**tool_call.arguments,
- **self.tool_name_to_args.get(tool_name_str, {}),
+ **self.tool_name_to_args.get(tool_name, {}),
},
)
- logger.debug(f"tool call {tool_name_str} completed with result: {result}")
+ logger.debug(f"tool call {tool_name} completed with result: {result}")
return result
async def handle_documents(
diff --git a/llama_stack/providers/inline/inference/vllm/openai_utils.py b/llama_stack/providers/inline/inference/vllm/openai_utils.py
index 90b5398f9..73a4fcbbe 100644
--- a/llama_stack/providers/inline/inference/vllm/openai_utils.py
+++ b/llama_stack/providers/inline/inference/vllm/openai_utils.py
@@ -16,7 +16,7 @@ from llama_stack.apis.inference import (
ToolChoice,
UserMessage,
)
-from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
+from llama_stack.models.llama.datatypes import ToolDefinition, ToolType
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
@@ -65,7 +65,7 @@ def _llama_stack_tools_to_openai_tools(
result = []
for t in tools:
- if isinstance(t.tool_name, BuiltinTool):
+ if t.type != ToolType.function.value:
raise NotImplementedError("Built-in tools not yet implemented")
if t.parameters is None:
parameters = None
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index eda1a179c..acdc79ca6 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -45,7 +45,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
-from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
+from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
@@ -110,6 +110,8 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]
for tool in tools:
properties = {}
compat_required = []
+
+ tool_name = tool.name
if tool.parameters:
for tool_key, tool_param in tool.parameters.items():
properties[tool_key] = {"type": tool_param.param_type}
@@ -120,12 +122,6 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]
if tool_param.required:
compat_required.append(tool_key)
- # The tool.tool_name can be a str or a BuiltinTool enum. If
- # it's the latter, convert to a string.
- tool_name = tool.tool_name
- if isinstance(tool_name, BuiltinTool):
- tool_name = tool_name.value
-
compat_tool = {
"type": "function",
"function": {
diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
index 78b47eb56..56f8f1146 100644
--- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
+++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
@@ -17,7 +17,6 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import BraveSearchToolConfig
@@ -61,7 +60,6 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
parameter_type="string",
)
],
- built_in_type=BuiltinTool.brave_search,
)
]
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index b264c7312..46bed694b 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -80,12 +80,12 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.models.llama.datatypes import (
- BuiltinTool,
GreedySamplingStrategy,
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
+ ToolType,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
@@ -271,7 +271,7 @@ def process_chat_completion_response(
else:
# only return tool_calls if provided in the request
new_tool_calls = []
- request_tools = {t.tool_name: t for t in request.tools}
+ request_tools = {t.name: t for t in request.tools}
for t in raw_message.tool_calls:
if t.tool_name in request_tools:
new_tool_calls.append(t)
@@ -423,7 +423,7 @@ async def process_chat_completion_stream_response(
)
)
- request_tools = {t.tool_name: t for t in request.tools}
+ request_tools = {t.name: t for t in request.tools}
for tool_call in message.tool_calls:
if tool_call.tool_name in request_tools:
yield ChatCompletionResponseStreamChunk(
@@ -574,7 +574,7 @@ async def convert_message_to_openai_dict_new(
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
- name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
+ name=tool.tool_name,
arguments=json.dumps(tool.arguments),
),
type="function",
@@ -638,7 +638,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
Convert a ToolDefinition to an OpenAI API-compatible dictionary.
ToolDefinition:
- tool_name: str | BuiltinTool
+ tool_name: str
description: Optional[str]
parameters: Optional[Dict[str, ToolParamDefinition]]
@@ -677,10 +677,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
}
function = out["function"]
- if isinstance(tool.tool_name, BuiltinTool):
- function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient?
- else:
- function.update(name=tool.tool_name)
+ function.update(name=tool.name)
if tool.description:
function.update(description=tool.description)
@@ -761,6 +758,7 @@ def _convert_openai_tool_calls(
return [
ToolCall(
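+            # OpenAI emits only function-style tool calls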
+ type=ToolType.function,
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
@@ -975,6 +973,7 @@ async def convert_openai_chat_completion_stream(
try:
arguments = json.loads(buffer["arguments"])
tool_call = ToolCall(
+ type=ToolType.function,
call_id=buffer["call_id"],
tool_name=buffer["name"],
arguments=arguments,
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 1edf445c0..938c49eef 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -43,6 +43,7 @@ from llama_stack.models.llama.datatypes import (
Role,
StopReason,
ToolPromptFormat,
+ ToolType,
is_multimodal,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
@@ -374,8 +375,8 @@ def augment_messages_for_tools_llama_3_1(
messages.append(SystemMessage(content=sys_content))
- has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
- if has_custom_tools:
+ custom_tools = [t for t in request.tools if t.type == ToolType.function.value]
+ if custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
@@ -384,7 +385,6 @@ def augment_messages_for_tools_llama_3_1(
else:
raise ValueError(f"Non supported ToolPromptFormat {fmt}")
- custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
messages.append(UserMessage(content=custom_template.render()))
@@ -407,7 +407,7 @@ def augment_messages_for_tools_llama_3_2(
sys_content = ""
custom_tools, builtin_tools = [], []
for t in request.tools:
- if isinstance(t.tool_name, str):
+ if t.type == ToolType.function.value:
custom_tools.append(t)
else:
builtin_tools.append(t)
@@ -419,7 +419,7 @@ def augment_messages_for_tools_llama_3_2(
sys_content += tool_template.render()
sys_content += "\n"
- custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
+ custom_tools = [t for t in request.tools if t.type == ToolType.function.value]
if custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list:
diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py
index 7011dc02d..3a2c57fd4 100644
--- a/tests/integration/agents/test_agents.py
+++ b/tests/integration/agents/test_agents.py
@@ -188,16 +188,12 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
}
],
session_id=session_id,
+ stream=False,
)
- logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
- logs_str = "".join(logs)
-
- assert "tool_execution>" in logs_str
- assert "Tool:brave_search Response:" in logs_str
- assert "mark zuckerberg" in logs_str.lower()
- if len(agent_config["output_shields"]) > 0:
- assert "No Violation" in logs_str
+ tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
+ assert tool_execution_step.tool_calls[0].tool_name == "web_search"
+ assert "mark zuckerberg" in response.output_message.content.lower()
def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py
index c9649df60..62bbd536d 100644
--- a/tests/integration/inference/test_text_inference.py
+++ b/tests/integration/inference/test_text_inference.py
@@ -284,6 +284,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
"test_case",
[
"inference:chat_completion:tool_calling",
+ "inference:chat_completion:tool_calling_deprecated",
],
)
def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
@@ -300,7 +301,9 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_mo
assert response.completion_message.role == "assistant"
assert len(response.completion_message.tool_calls) == 1
- assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
+ assert response.completion_message.tool_calls[0].tool_name == (
+ tc["tools"][0]["tool_name"] if "tool_name" in tc["tools"][0] else tc["tools"][0]["name"]
+ )
assert response.completion_message.tool_calls[0].arguments == tc["expected"]
@@ -334,7 +337,7 @@ def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models
stream=True,
)
tool_invocation_content = extract_tool_invocation_content(response)
- expected_tool_name = tc["tools"][0]["tool_name"]
+ expected_tool_name = tc["tools"][0]["tool_name"] if "tool_name" in tc["tools"][0] else tc["tools"][0]["name"]
expected_argument = tc["expected"]
assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]"
@@ -358,7 +361,7 @@ def test_text_chat_completion_with_tool_choice_required(client_with_models, text
stream=True,
)
tool_invocation_content = extract_tool_invocation_content(response)
- expected_tool_name = tc["tools"][0]["tool_name"]
+ expected_tool_name = tc["tools"][0]["tool_name"] if "tool_name" in tc["tools"][0] else tc["tools"][0]["name"]
expected_argument = tc["expected"]
assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]"
@@ -432,14 +435,11 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
):
tc = TestCase(test_case)
- # TODO: more dynamic lookup on tool_prompt_format for model family
- tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
request = {
"model_id": text_model_id,
"messages": tc["messages"],
"tools": tc["tools"],
"tool_choice": "auto",
- "tool_prompt_format": tool_prompt_format,
"stream": streaming,
}
@@ -457,3 +457,30 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
else:
for tc in response.completion_message.tool_calls:
assert tc.tool_name == "get_object_namespace_list"
+
+
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ "inference:chat_completion:tool_calling_builtin_web_search",
+ "inference:chat_completion:tool_calling_builtin_brave_search",
+ "inference:chat_completion:tool_calling_builtin_code_interpreter",
+ "inference:chat_completion:tool_calling_builtin_code_interpreter_deprecated",
+ ],
+)
+def test_text_chat_completion_tool_calling_builtin(client_with_models, text_model_id, test_case):
+ tc = TestCase(test_case)
+
+ request = {
+ "model_id": text_model_id,
+ "messages": tc["messages"],
+ "tools": tc["tools"],
+ "tool_choice": "auto",
+ "stream": False,
+ }
+
+ response = client_with_models.inference.chat_completion(**request)
+
+ for tool_call in response.completion_message.tool_calls:
+ print(tool_call)
+ assert tool_call.tool_name == tc["expected"]
diff --git a/tests/integration/test_cases/inference/chat_completion.json b/tests/integration/test_cases/inference/chat_completion.json
index e87c046b0..3b47b0816 100644
--- a/tests/integration/test_cases/inference/chat_completion.json
+++ b/tests/integration/test_cases/inference/chat_completion.json
@@ -49,7 +49,7 @@
"expected": "Washington"
}
},
- "tool_calling": {
+ "tool_calling_deprecated": {
"data": {
"messages": [
{"role": "system", "content": "Pretend you are a weather assistant."},
@@ -72,6 +72,30 @@
}
}
},
+ "tool_calling": {
+ "data": {
+ "messages": [
+ {"role": "system", "content": "Pretend you are a weather assistant."},
+ {"role": "user", "content": "What's the weather like in San Francisco?"}
+ ],
+ "tools": [
+ {
+ "type": "function",
+ "name": "get_weather",
+ "description": "Get the current weather",
+ "parameters": {
+ "location": {
+ "param_type": "string",
+ "description": "The city and state (both required), e.g. San Francisco, CA."
+ }
+ }
+ }
+ ],
+ "expected": {
+ "location": "San Francisco, CA"
+ }
+ }
+ },
"sample_messages_tool_calling": {
"data": {
"messages": [
@@ -128,6 +152,56 @@
}
}
},
+ "tool_calling_builtin_web_search": {
+ "data": {
+ "messages": [
+ {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."},
+ {"role": "user", "content": "What's the weather like in San Francisco?"}
+ ],
+ "tools": [
+ {
+ "type": "web_search"
+ }
+ ],
+ "expected": "web_search"
+ }
+ },
+ "tool_calling_builtin_brave_search": {
+ "data": {
+ "messages": [
+ {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."},
+ {"role": "user", "content": "What's the weather like in San Francisco?"}
+ ],
+ "tools": [{
+ "tool_name": "brave_search"
+ }],
+ "expected": "web_search"
+ }
+ },
+ "tool_calling_builtin_code_interpreter_deprecated": {
+ "data": {
+ "messages": [
+ {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."},
+ {"role": "user", "content": "plot log(x) from -10 to 10"}
+ ],
+ "tools": [{
+ "tool_name": "code_interpreter"
+ }],
+ "expected": "code_interpreter"
+ }
+ },
+ "tool_calling_builtin_code_interpreter": {
+ "data": {
+ "messages": [
+ {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."},
+ {"role": "user", "content": "plot log(x) from -10 to 10"}
+ ],
+ "tools": [{
+ "type": "code_interpreter"
+ }],
+ "expected": "code_interpreter"
+ }
+ },
"tool_calling_tools_absent": {
"data": {
"messages": [
@@ -146,6 +220,7 @@
"tool_calls": [
{
"call_id": "1",
+ "type": "function",
"tool_name": "get_object_namespace_list",
"arguments": {
"kind": "pod",
diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py
index 0e2780e50..b2c1ae657 100644
--- a/tests/unit/models/test_prompt_adapter.py
+++ b/tests/unit/models/test_prompt_adapter.py
@@ -5,7 +5,8 @@
# the root directory of this source tree.
import asyncio
-import unittest
+
+import pytest
from llama_stack.apis.inference import (
ChatCompletionRequest,
@@ -17,10 +18,11 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.models.llama.datatypes import (
- BuiltinTool,
- ToolDefinition,
+ CodeInterpreterTool,
+ FunctionTool,
ToolParamDefinition,
ToolPromptFormat,
+ WebSearchTool,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
@@ -31,258 +33,269 @@ MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct"
-class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
- async def asyncSetUp(self):
- asyncio.get_running_loop().set_debug(False)
+@pytest.fixture(autouse=True)
+def setup_loop():
+ loop = asyncio.get_event_loop()
+ loop.set_debug(False)
+ return loop
- async def test_system_default(self):
- content = "Hello !"
- request = ChatCompletionRequest(
- model=MODEL,
- messages=[
- UserMessage(content=content),
- ],
- )
- messages = chat_completion_request_to_messages(request, MODEL)
- self.assertEqual(len(messages), 2)
- self.assertEqual(messages[-1].content, content)
- self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
- async def test_system_builtin_only(self):
- content = "Hello !"
- request = ChatCompletionRequest(
- model=MODEL,
- messages=[
- UserMessage(content=content),
- ],
- tools=[
- ToolDefinition(tool_name=BuiltinTool.code_interpreter),
- ToolDefinition(tool_name=BuiltinTool.brave_search),
- ],
- )
- messages = chat_completion_request_to_messages(request, MODEL)
- self.assertEqual(len(messages), 2)
- self.assertEqual(messages[-1].content, content)
- self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
- self.assertTrue("Tools: brave_search" in messages[0].content)
+@pytest.mark.asyncio
+async def test_system_default():
+ content = "Hello !"
+ request = ChatCompletionRequest(
+ model=MODEL,
+ messages=[
+ UserMessage(content=content),
+ ],
+ )
+ messages = chat_completion_request_to_messages(request, MODEL)
+ assert len(messages) == 2
+ assert messages[-1].content == content
+ assert "Cutting Knowledge Date: December 2023" in messages[0].content
- async def test_system_custom_only(self):
- content = "Hello !"
- request = ChatCompletionRequest(
- model=MODEL,
- messages=[
- UserMessage(content=content),
- ],
- tools=[
- ToolDefinition(
- tool_name="custom1",
- description="custom1 tool",
- parameters={
- "param1": ToolParamDefinition(
- param_type="str",
- description="param1 description",
- required=True,
- ),
- },
- )
- ],
- tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
- )
- messages = chat_completion_request_to_messages(request, MODEL)
- self.assertEqual(len(messages), 3)
- self.assertTrue("Environment: ipython" in messages[0].content)
- self.assertTrue("Return function calls in JSON format" in messages[1].content)
- self.assertEqual(messages[-1].content, content)
+@pytest.mark.asyncio
+async def test_system_builtin_only():
+ content = "Hello !"
+ request = ChatCompletionRequest(
+ model=MODEL,
+ messages=[
+ UserMessage(content=content),
+ ],
+ tools=[
+ CodeInterpreterTool(),
+ WebSearchTool(),
+ ],
+ )
+ messages = chat_completion_request_to_messages(request, MODEL)
+ assert len(messages) == 2
+ assert messages[-1].content == content
+ assert "Cutting Knowledge Date: December 2023" in messages[0].content
+ assert "Tools: brave_search" in messages[0].content
- async def test_system_custom_and_builtin(self):
- content = "Hello !"
- request = ChatCompletionRequest(
- model=MODEL,
- messages=[
- UserMessage(content=content),
- ],
- tools=[
- ToolDefinition(tool_name=BuiltinTool.code_interpreter),
- ToolDefinition(tool_name=BuiltinTool.brave_search),
- ToolDefinition(
- tool_name="custom1",
- description="custom1 tool",
- parameters={
- "param1": ToolParamDefinition(
- param_type="str",
- description="param1 description",
- required=True,
- ),
- },
- ),
- ],
- )
- messages = chat_completion_request_to_messages(request, MODEL)
- self.assertEqual(len(messages), 3)
- self.assertTrue("Environment: ipython" in messages[0].content)
- self.assertTrue("Tools: brave_search" in messages[0].content)
+@pytest.mark.asyncio
+async def test_system_custom_only():
+ content = "Hello !"
+ request = ChatCompletionRequest(
+ model=MODEL,
+ messages=[
+ UserMessage(content=content),
+ ],
+ tools=[
+ FunctionTool(
+ name="custom1",
+ description="custom1 tool",
+ parameters={
+ "param1": ToolParamDefinition(
+ param_type="str",
+ description="param1 description",
+ required=True,
+ ),
+ },
+ )
+ ],
+ tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
+ )
+ messages = chat_completion_request_to_messages(request, MODEL)
+ assert len(messages) == 3
+ assert "Environment: ipython" in messages[0].content
+ assert "Return function calls in JSON format" in messages[1].content
+ assert messages[-1].content == content
- self.assertTrue("Return function calls in JSON format" in messages[1].content)
- self.assertEqual(messages[-1].content, content)
- async def test_completion_message_encoding(self):
- request = ChatCompletionRequest(
- model=MODEL3_2,
- messages=[
- UserMessage(content="hello"),
- CompletionMessage(
- content="",
- stop_reason=StopReason.end_of_turn,
- tool_calls=[
- ToolCall(
- tool_name="custom1",
- arguments={"param1": "value1"},
- call_id="123",
- )
- ],
- ),
- ],
- tools=[
- ToolDefinition(
- tool_name="custom1",
- description="custom1 tool",
- parameters={
- "param1": ToolParamDefinition(
- param_type="str",
- description="param1 description",
- required=True,
- ),
- },
- ),
- ],
- tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
- )
- prompt = await chat_completion_request_to_prompt(request, request.model)
- self.assertIn('[custom1(param1="value1")]', prompt)
-
- request.model = MODEL
- request.tool_config.tool_prompt_format = ToolPromptFormat.json
- prompt = await chat_completion_request_to_prompt(request, request.model)
- self.assertIn(
- '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
- prompt,
- )
-
- async def test_user_provided_system_message(self):
- content = "Hello !"
- system_prompt = "You are a pirate"
- request = ChatCompletionRequest(
- model=MODEL,
- messages=[
- SystemMessage(content=system_prompt),
- UserMessage(content=content),
- ],
- tools=[
- ToolDefinition(tool_name=BuiltinTool.code_interpreter),
- ],
- )
- messages = chat_completion_request_to_messages(request, MODEL)
- self.assertEqual(len(messages), 2, messages)
- self.assertTrue(messages[0].content.endswith(system_prompt))
-
- self.assertEqual(messages[-1].content, content)
-
- async def test_repalce_system_message_behavior_builtin_tools(self):
- content = "Hello !"
- system_prompt = "You are a pirate"
- request = ChatCompletionRequest(
- model=MODEL,
- messages=[
- SystemMessage(content=system_prompt),
- UserMessage(content=content),
- ],
- tools=[
- ToolDefinition(tool_name=BuiltinTool.code_interpreter),
- ],
- tool_config=ToolConfig(
- tool_choice="auto",
- tool_prompt_format="python_list",
- system_message_behavior="replace",
+@pytest.mark.asyncio
+async def test_system_custom_and_builtin():
+ content = "Hello !"
+ request = ChatCompletionRequest(
+ model=MODEL,
+ messages=[
+ UserMessage(content=content),
+ ],
+ tools=[
+ CodeInterpreterTool(),
+ WebSearchTool(),
+ FunctionTool(
+ name="custom1",
+ description="custom1 tool",
+ parameters={
+ "param1": ToolParamDefinition(
+ param_type="str",
+ description="param1 description",
+ required=True,
+ ),
+ },
),
- )
- messages = chat_completion_request_to_messages(request, MODEL3_2)
- self.assertEqual(len(messages), 2, messages)
- self.assertTrue(messages[0].content.endswith(system_prompt))
- self.assertIn("Environment: ipython", messages[0].content)
- self.assertEqual(messages[-1].content, content)
+ ],
+ )
+ messages = chat_completion_request_to_messages(request, MODEL)
+ assert len(messages) == 3
+ assert "Environment: ipython" in messages[0].content
+ assert "Tools: brave_search" in messages[0].content
+ assert "Return function calls in JSON format" in messages[1].content
+ assert messages[-1].content == content
- async def test_repalce_system_message_behavior_custom_tools(self):
- content = "Hello !"
- system_prompt = "You are a pirate"
- request = ChatCompletionRequest(
- model=MODEL,
- messages=[
- SystemMessage(content=system_prompt),
- UserMessage(content=content),
- ],
- tools=[
- ToolDefinition(tool_name=BuiltinTool.code_interpreter),
- ToolDefinition(
- tool_name="custom1",
- description="custom1 tool",
- parameters={
- "param1": ToolParamDefinition(
- param_type="str",
- description="param1 description",
- required=True,
- ),
- },
- ),
- ],
- tool_config=ToolConfig(
- tool_choice="auto",
- tool_prompt_format="python_list",
- system_message_behavior="replace",
+
+@pytest.mark.asyncio
+async def test_completion_message_encoding():
+ request = ChatCompletionRequest(
+ model=MODEL3_2,
+ messages=[
+ UserMessage(content="hello"),
+ CompletionMessage(
+ content="",
+ stop_reason=StopReason.end_of_turn,
+ tool_calls=[
+ ToolCall(
+ type="function",
+ tool_name="custom1",
+ arguments={"param1": "value1"},
+ call_id="123",
+ )
+ ],
),
- )
- messages = chat_completion_request_to_messages(request, MODEL3_2)
-
- self.assertEqual(len(messages), 2, messages)
- self.assertTrue(messages[0].content.endswith(system_prompt))
- self.assertIn("Environment: ipython", messages[0].content)
- self.assertEqual(messages[-1].content, content)
-
- async def test_replace_system_message_behavior_custom_tools_with_template(self):
- content = "Hello !"
- system_prompt = "You are a pirate {{ function_description }}"
- request = ChatCompletionRequest(
- model=MODEL,
- messages=[
- SystemMessage(content=system_prompt),
- UserMessage(content=content),
- ],
- tools=[
- ToolDefinition(tool_name=BuiltinTool.code_interpreter),
- ToolDefinition(
- tool_name="custom1",
- description="custom1 tool",
- parameters={
- "param1": ToolParamDefinition(
- param_type="str",
- description="param1 description",
- required=True,
- ),
- },
- ),
- ],
- tool_config=ToolConfig(
- tool_choice="auto",
- tool_prompt_format="python_list",
- system_message_behavior="replace",
+ ],
+ tools=[
+ FunctionTool(
+ name="custom1",
+ description="custom1 tool",
+ parameters={
+ "param1": ToolParamDefinition(
+ param_type="str",
+ description="param1 description",
+ required=True,
+ ),
+ },
),
- )
- messages = chat_completion_request_to_messages(request, MODEL3_2)
+ ],
+ tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
+ )
+ prompt = await chat_completion_request_to_prompt(request, request.model)
+ assert '[custom1(param1="value1")]' in prompt
- self.assertEqual(len(messages), 2, messages)
- self.assertIn("Environment: ipython", messages[0].content)
- self.assertIn("You are a pirate", messages[0].content)
- # function description is present in the system prompt
- self.assertIn('"name": "custom1"', messages[0].content)
- self.assertEqual(messages[-1].content, content)
+ request.model = MODEL
+ request.tool_config.tool_prompt_format = ToolPromptFormat.json
+ prompt = await chat_completion_request_to_prompt(request, request.model)
+ assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
+
+
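
The test above encodes the same tool call twice, so the two wire formats can be compared directly; both strings are copied verbatim from its assertions:

# ToolPromptFormat.python_list (used with Llama 3.2 above):
python_list_encoding = '[custom1(param1="value1")]'
# ToolPromptFormat.json (used with Llama 3.1 above):
json_encoding = '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}'
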
+@pytest.mark.asyncio
+async def test_user_provided_system_message():
+ content = "Hello !"
+ system_prompt = "You are a pirate"
+ request = ChatCompletionRequest(
+ model=MODEL,
+ messages=[
+ SystemMessage(content=system_prompt),
+ UserMessage(content=content),
+ ],
+ tools=[
+ CodeInterpreterTool(),
+ ],
+ )
+ messages = chat_completion_request_to_messages(request, MODEL)
+ assert len(messages) == 2
+ assert messages[0].content.endswith(system_prompt)
+ assert messages[-1].content == content
+
+
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_builtin_tools():
+ content = "Hello !"
+ system_prompt = "You are a pirate"
+ request = ChatCompletionRequest(
+ model=MODEL,
+ messages=[
+ SystemMessage(content=system_prompt),
+ UserMessage(content=content),
+ ],
+ tools=[
+ CodeInterpreterTool(),
+ ],
+ tool_config=ToolConfig(
+ tool_choice="auto",
+ tool_prompt_format="python_list",
+ system_message_behavior="replace",
+ ),
+ )
+ messages = chat_completion_request_to_messages(request, MODEL3_2)
+ assert len(messages) == 2
+ assert messages[0].content.endswith(system_prompt)
+ assert "Environment: ipython" in messages[0].content
+ assert messages[-1].content == content
+
+
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_custom_tools():
+ content = "Hello !"
+ system_prompt = "You are a pirate"
+ request = ChatCompletionRequest(
+ model=MODEL,
+ messages=[
+ SystemMessage(content=system_prompt),
+ UserMessage(content=content),
+ ],
+ tools=[
+ CodeInterpreterTool(),
+ FunctionTool(
+ name="custom1",
+ description="custom1 tool",
+ parameters={
+ "param1": ToolParamDefinition(
+ param_type="str",
+ description="param1 description",
+ required=True,
+ ),
+ },
+ ),
+ ],
+ tool_config=ToolConfig(
+ tool_choice="auto",
+ tool_prompt_format="python_list",
+ system_message_behavior="replace",
+ ),
+ )
+ messages = chat_completion_request_to_messages(request, MODEL3_2)
+ assert len(messages) == 2
+ assert messages[0].content.endswith(system_prompt)
+ assert "Environment: ipython" in messages[0].content
+ assert messages[-1].content == content
+
+
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_custom_tools_with_template():
+ content = "Hello !"
+ system_prompt = "You are a pirate {{ function_description }}"
+ request = ChatCompletionRequest(
+ model=MODEL,
+ messages=[
+ SystemMessage(content=system_prompt),
+ UserMessage(content=content),
+ ],
+ tools=[
+ CodeInterpreterTool(),
+ FunctionTool(
+ name="custom1",
+ description="custom1 tool",
+ parameters={
+ "param1": ToolParamDefinition(
+ param_type="str",
+ description="param1 description",
+ required=True,
+ ),
+ },
+ ),
+ ],
+ tool_config=ToolConfig(
+ tool_choice="auto",
+ tool_prompt_format="python_list",
+ system_message_behavior="replace",
+ ),
+ )
+ messages = chat_completion_request_to_messages(request, MODEL3_2)
+ assert len(messages) == 2
+ assert "Environment: ipython" in messages[0].content
+ assert "You are a pirate" in messages[0].content
+ assert '"name": "custom1"' in messages[0].content
+ assert messages[-1].content == content
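
Taken together, the migration in this file is mechanical. A before/after sketch of the three substitutions, using only constructors exercised by the tests above:

from llama_stack.models.llama.datatypes import (
    CodeInterpreterTool,
    FunctionTool,
    WebSearchTool,
)

tools = [
    # was: ToolDefinition(tool_name=BuiltinTool.code_interpreter)
    CodeInterpreterTool(),
    # was: ToolDefinition(tool_name=BuiltinTool.brave_search)
    WebSearchTool(),
    # was: ToolDefinition(tool_name="custom1", description=..., parameters=...)
    FunctionTool(name="custom1", description="custom1 tool"),
]
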
diff --git a/tests/unit/models/test_system_prompts.py b/tests/unit/models/test_system_prompts.py
index 1f4ccc7e3..e068ce775 100644
--- a/tests/unit/models/test_system_prompts.py
+++ b/tests/unit/models/test_system_prompts.py
@@ -33,7 +33,7 @@ class PromptTemplateTests(unittest.TestCase):
if not example:
continue
for tool in example:
- assert tool.tool_name in text
+ assert tool.name in text
def test_system_default(self):
generator = SystemDefaultGenerator()
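
The attribute rename checked above, in isolation: FunctionTool exposes `name` where the removed ToolDefinition exposed `tool_name`. A minimal sketch:

from llama_stack.models.llama.datatypes import FunctionTool

tool = FunctionTool(name="custom1", description="custom1 tool")
assert tool.name == "custom1"  # was: tool.tool_name
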