From 7027b537e015c2d7a11db8632a21ffbba48b869d Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Wed, 26 Mar 2025 11:14:40 -0700 Subject: [PATCH] feat: RFC: tools API rework # What does this PR do? This PR proposes updates to the tools API in Inference and Agent. Goals: 1. Agent's tool specification should be consistent with Inference's tool spec, but with add-ons. 2. Formal types should be defined for built in tools. Currently Agent tools args are untyped, e.g. how does one know that `builtin::rag_tool` takes a `vector_db_ids` param or even how to know 'builtin::rag_tool' is even available (in code, outside of docs)? Inference: 1. BuiltinTool is to be removed and replaced by a formal `type` parameter. 2. 'brave_search' is replaced by 'web_search' to be more generic. It will still be translated back to brave_search when the prompt is constructed to be consistent with model training. 3. I'm not sure what `photogen` is. Maybe it can be removed? Agent: 1. Uses the same format as in Inference for builtin tools. 2. New tools types are added, i.e. knowledge_sesarch (currently rag_tool), and MCP tool. 3. Toolgroup as a concept will be removed since it's really only used for MCP. 4. Instead MCPTool is its own type and available tools provided by the server will be expanded by default. Users can specify a subset of tool names if desired. Example snippet: ``` agent = Agent( client, model=model_id, instructions="You are a helpful assistant. Use the tools you have access to for providing relevant answers.", tools=[ KnowledgeSearchTool(vector_store_id="1234"), KnowledgeSearchTool(vector_store_id="5678", name="paper_search", description="Search research papers"), KnowledgeSearchTool(vector_store_id="1357", name="wiki_search", description="Search wiki pages"), # no need to register toolgroup, just pass in the server uri # all available tools will be used MCPTool(server_uri="http://localhost:8000/sse"), # can specify a subset of available tools MCPTool(server_uri="http://localhost:8000/sse", tool_names=["list_directory"]), MCPTool(server_uri="http://localhost:8000/sse", tool_names=["list_directory"]), # custom tool my_custom_tool, ] ) ``` ## Test Plan # What does this PR do? ## Test Plan # What does this PR do? ## Test Plan --- docs/_static/llama-stack-spec.html | 251 ++++++--- docs/_static/llama-stack-spec.yaml | 161 ++++-- llama_stack/apis/inference/inference.py | 20 +- llama_stack/distribution/routers/routers.py | 36 +- llama_stack/models/llama/datatypes.py | 166 +++++- .../models/llama/llama3/chat_format.py | 17 +- llama_stack/models/llama/llama3/interface.py | 35 +- .../llama3/prompt_templates/system_prompts.py | 46 +- llama_stack/models/llama/llama3/tool_utils.py | 35 +- llama_stack/models/llama/llama3_1/prompts.py | 5 +- llama_stack/models/llama/llama3_3/prompts.py | 5 +- .../agents/meta_reference/agent_instance.py | 23 +- .../inline/inference/vllm/openai_utils.py | 4 +- .../providers/remote/inference/vllm/vllm.py | 10 +- .../tool_runtime/brave_search/brave_search.py | 2 - .../utils/inference/openai_compat.py | 17 +- .../utils/inference/prompt_adapter.py | 10 +- tests/integration/agents/test_agents.py | 12 +- .../inference/test_text_inference.py | 39 +- .../test_cases/inference/chat_completion.json | 77 ++- tests/unit/models/test_prompt_adapter.py | 503 +++++++++--------- tests/unit/models/test_system_prompts.py | 2 +- 22 files changed, 951 insertions(+), 525 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 8a46a89ad..2274f86e7 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -3798,6 +3798,21 @@ ], "title": "AppendRowsRequest" }, + "CodeInterpreterTool": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "code_interpreter", + "default": "code_interpreter" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "CodeInterpreterTool" + }, "CompletionMessage": { "type": "object", "properties": { @@ -3837,6 +3852,34 @@ "title": "CompletionMessage", "description": "A message containing the model's (assistant) response in a chat conversation." }, + "FunctionTool": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "function", + "default": "function" + }, + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "parameters": { + "type": "object", + "additionalProperties": { + "$ref": "#/components/schemas/ToolParamDefinition" + } + } + }, + "additionalProperties": false, + "required": [ + "type", + "name" + ], + "title": "FunctionTool" + }, "GrammarResponseFormat": { "type": "object", "properties": { @@ -4138,25 +4181,21 @@ "ToolCall": { "type": "object", "properties": { + "type": { + "type": "string", + "enum": [ + "function", + "web_search", + "wolfram_alpha", + "code_interpreter" + ], + "title": "ToolType" + }, "call_id": { "type": "string" }, "tool_name": { - "oneOf": [ - { - "type": "string", - "enum": [ - "brave_search", - "wolfram_alpha", - "photogen", - "code_interpreter" - ], - "title": "BuiltinTool" - }, - { - "type": "string" - } - ] + "type": "string" }, "arguments": { "oneOf": [ @@ -4237,48 +4276,13 @@ }, "additionalProperties": false, "required": [ + "type", "call_id", "tool_name", "arguments" ], "title": "ToolCall" }, - "ToolDefinition": { - "type": "object", - "properties": { - "tool_name": { - "oneOf": [ - { - "type": "string", - "enum": [ - "brave_search", - "wolfram_alpha", - "photogen", - "code_interpreter" - ], - "title": "BuiltinTool" - }, - { - "type": "string" - } - ] - }, - "description": { - "type": "string" - }, - "parameters": { - "type": "object", - "additionalProperties": { - "$ref": "#/components/schemas/ToolParamDefinition" - } - } - }, - "additionalProperties": false, - "required": [ - "tool_name" - ], - "title": "ToolDefinition" - }, "ToolParamDefinition": { "type": "object", "properties": { @@ -4428,6 +4432,36 @@ "title": "UserMessage", "description": "A message from the user in a chat conversation." }, + "WebSearchTool": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "web_search", + "default": "web_search" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "WebSearchTool" + }, + "WolframAlphaTool": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "wolfram_alpha", + "default": "wolfram_alpha" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "WolframAlphaTool" + }, "BatchChatCompletionRequest": { "type": "object", "properties": { @@ -4449,7 +4483,29 @@ "tools": { "type": "array", "items": { - "$ref": "#/components/schemas/ToolDefinition" + "oneOf": [ + { + "$ref": "#/components/schemas/WebSearchTool" + }, + { + "$ref": "#/components/schemas/WolframAlphaTool" + }, + { + "$ref": "#/components/schemas/CodeInterpreterTool" + }, + { + "$ref": "#/components/schemas/FunctionTool" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "web_search": "#/components/schemas/WebSearchTool", + "wolfram_alpha": "#/components/schemas/WolframAlphaTool", + "code_interpreter": "#/components/schemas/CodeInterpreterTool", + "function": "#/components/schemas/FunctionTool" + } + } } }, "tool_choice": { @@ -4734,6 +4790,41 @@ "title": "ToolConfig", "description": "Configuration for tool use." }, + "ToolDefinitionDeprecated": { + "type": "object", + "properties": { + "tool_name": { + "oneOf": [ + { + "type": "string", + "enum": [ + "brave_search", + "wolfram_alpha", + "code_interpreter" + ], + "title": "BuiltinTool" + }, + { + "type": "string" + } + ] + }, + "description": { + "type": "string" + }, + "parameters": { + "type": "object", + "additionalProperties": { + "$ref": "#/components/schemas/ToolParamDefinition" + } + } + }, + "additionalProperties": false, + "required": [ + "tool_name" + ], + "title": "ToolDefinitionDeprecated" + }, "ChatCompletionRequest": { "type": "object", "properties": { @@ -4753,10 +4844,42 @@ "description": "Parameters to control the sampling strategy" }, "tools": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolDefinition" - }, + "oneOf": [ + { + "type": "array", + "items": { + "oneOf": [ + { + "$ref": "#/components/schemas/WebSearchTool" + }, + { + "$ref": "#/components/schemas/WolframAlphaTool" + }, + { + "$ref": "#/components/schemas/CodeInterpreterTool" + }, + { + "$ref": "#/components/schemas/FunctionTool" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "web_search": "#/components/schemas/WebSearchTool", + "wolfram_alpha": "#/components/schemas/WolframAlphaTool", + "code_interpreter": "#/components/schemas/CodeInterpreterTool", + "function": "#/components/schemas/FunctionTool" + } + } + } + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolDefinitionDeprecated" + } + } + ], "description": "(Optional) List of tool definitions available to the model" }, "tool_choice": { @@ -5630,21 +5753,7 @@ "type": "string" }, "tool_name": { - "oneOf": [ - { - "type": "string", - "enum": [ - "brave_search", - "wolfram_alpha", - "photogen", - "code_interpreter" - ], - "title": "BuiltinTool" - }, - { - "type": "string" - } - ] + "type": "string" }, "content": { "$ref": "#/components/schemas/InterleavedContent" diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 0b8f90490..d1517bd1e 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2607,6 +2607,17 @@ components: required: - rows title: AppendRowsRequest + CodeInterpreterTool: + type: object + properties: + type: + type: string + const: code_interpreter + default: code_interpreter + additionalProperties: false + required: + - type + title: CodeInterpreterTool CompletionMessage: type: object properties: @@ -2646,6 +2657,26 @@ components: title: CompletionMessage description: >- A message containing the model's (assistant) response in a chat conversation. + FunctionTool: + type: object + properties: + type: + type: string + const: function + default: function + name: + type: string + description: + type: string + parameters: + type: object + additionalProperties: + $ref: '#/components/schemas/ToolParamDefinition' + additionalProperties: false + required: + - type + - name + title: FunctionTool GrammarResponseFormat: type: object properties: @@ -2851,18 +2882,18 @@ components: ToolCall: type: object properties: + type: + type: string + enum: + - function + - web_search + - wolfram_alpha + - code_interpreter + title: ToolType call_id: type: string tool_name: - oneOf: - - type: string - enum: - - brave_search - - wolfram_alpha - - photogen - - code_interpreter - title: BuiltinTool - - type: string + type: string arguments: oneOf: - type: string @@ -2894,33 +2925,11 @@ components: type: string additionalProperties: false required: + - type - call_id - tool_name - arguments title: ToolCall - ToolDefinition: - type: object - properties: - tool_name: - oneOf: - - type: string - enum: - - brave_search - - wolfram_alpha - - photogen - - code_interpreter - title: BuiltinTool - - type: string - description: - type: string - parameters: - type: object - additionalProperties: - $ref: '#/components/schemas/ToolParamDefinition' - additionalProperties: false - required: - - tool_name - title: ToolDefinition ToolParamDefinition: type: object properties: @@ -3031,6 +3040,28 @@ components: title: UserMessage description: >- A message from the user in a chat conversation. + WebSearchTool: + type: object + properties: + type: + type: string + const: web_search + default: web_search + additionalProperties: false + required: + - type + title: WebSearchTool + WolframAlphaTool: + type: object + properties: + type: + type: string + const: wolfram_alpha + default: wolfram_alpha + additionalProperties: false + required: + - type + title: WolframAlphaTool BatchChatCompletionRequest: type: object properties: @@ -3047,7 +3078,18 @@ components: tools: type: array items: - $ref: '#/components/schemas/ToolDefinition' + oneOf: + - $ref: '#/components/schemas/WebSearchTool' + - $ref: '#/components/schemas/WolframAlphaTool' + - $ref: '#/components/schemas/CodeInterpreterTool' + - $ref: '#/components/schemas/FunctionTool' + discriminator: + propertyName: type + mapping: + web_search: '#/components/schemas/WebSearchTool' + wolfram_alpha: '#/components/schemas/WolframAlphaTool' + code_interpreter: '#/components/schemas/CodeInterpreterTool' + function: '#/components/schemas/FunctionTool' tool_choice: type: string enum: @@ -3272,6 +3314,28 @@ components: additionalProperties: false title: ToolConfig description: Configuration for tool use. + ToolDefinitionDeprecated: + type: object + properties: + tool_name: + oneOf: + - type: string + enum: + - brave_search + - wolfram_alpha + - code_interpreter + title: BuiltinTool + - type: string + description: + type: string + parameters: + type: object + additionalProperties: + $ref: '#/components/schemas/ToolParamDefinition' + additionalProperties: false + required: + - tool_name + title: ToolDefinitionDeprecated ChatCompletionRequest: type: object properties: @@ -3290,9 +3354,24 @@ components: description: >- Parameters to control the sampling strategy tools: - type: array - items: - $ref: '#/components/schemas/ToolDefinition' + oneOf: + - type: array + items: + oneOf: + - $ref: '#/components/schemas/WebSearchTool' + - $ref: '#/components/schemas/WolframAlphaTool' + - $ref: '#/components/schemas/CodeInterpreterTool' + - $ref: '#/components/schemas/FunctionTool' + discriminator: + propertyName: type + mapping: + web_search: '#/components/schemas/WebSearchTool' + wolfram_alpha: '#/components/schemas/WolframAlphaTool' + code_interpreter: '#/components/schemas/CodeInterpreterTool' + function: '#/components/schemas/FunctionTool' + - type: array + items: + $ref: '#/components/schemas/ToolDefinitionDeprecated' description: >- (Optional) List of tool definitions available to the model tool_choice: @@ -3939,15 +4018,7 @@ components: call_id: type: string tool_name: - oneOf: - - type: string - enum: - - brave_search - - wolfram_alpha - - photogen - - code_interpreter - title: BuiltinTool - - type: string + type: string content: $ref: '#/components/schemas/InterleavedContent' metadata: diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 7d3539dcb..c51f4b971 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -17,18 +17,18 @@ from typing import ( runtime_checkable, ) -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem from llama_stack.apis.models import Model from llama_stack.apis.telemetry.telemetry import MetricResponseMixin from llama_stack.models.llama.datatypes import ( - BuiltinTool, SamplingParams, StopReason, ToolCall, ToolDefinition, + ToolDefinitionDeprecated, ToolPromptFormat, ) from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -156,23 +156,14 @@ Message = Annotated[ register_schema(Message, name="Message") +# TODO: move this to agent.py where this is used @json_schema_type class ToolResponse(BaseModel): call_id: str - tool_name: Union[BuiltinTool, str] + tool_name: str content: InterleavedContent metadata: Optional[Dict[str, Any]] = None - @field_validator("tool_name", mode="before") - @classmethod - def validate_field(cls, v): - if isinstance(v, str): - try: - return BuiltinTool(v) - except ValueError: - return v - return v - class ToolChoice(Enum): """Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model. @@ -462,7 +453,8 @@ class Inference(Protocol): model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = None, - tools: Optional[List[ToolDefinition]] = None, + # TODO: remove ToolDefinitionDeprecated in v0.1.10 + tools: Optional[List[ToolDefinition] | List[ToolDefinitionDeprecated]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = None, response_format: Optional[ResponseFormat] = None, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 6ff36a65c..764fa2406 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -32,6 +32,7 @@ from llama_stack.apis.inference import ( ToolChoice, ToolConfig, ToolDefinition, + ToolDefinitionDeprecated, ToolPromptFormat, ) from llama_stack.apis.models import Model, ModelType @@ -54,6 +55,9 @@ from llama_stack.apis.tools import ( ) from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.log import get_logger +from llama_stack.models.llama.datatypes import ( + ToolType, +) from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import RoutingTable @@ -229,7 +233,7 @@ class InferenceRouter(Inference): messages: List[Message], sampling_params: Optional[SamplingParams] = None, response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, + tools: Optional[List[ToolDefinition] | List[ToolDefinitionDeprecated]] = None, tool_choice: Optional[ToolChoice] = None, tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, @@ -259,24 +263,42 @@ class InferenceRouter(Inference): params["tool_prompt_format"] = tool_prompt_format tool_config = ToolConfig(**params) - tools = tools or [] + # TODO: remove ToolDefinitionDeprecated in v0.1.10 + converted_tools = [] + for tool in tools or []: + if isinstance(tool, ToolDefinitionDeprecated): + logger.warning(f"ToolDefinitionDeprecated: {tool}, use ToolDefinition instead") + converted_tools.append(tool.to_tool_definition()) + else: + converted_tools.append(tool) + if tool_config.tool_choice == ToolChoice.none: - tools = [] + converted_tools = [] elif tool_config.tool_choice == ToolChoice.auto: pass elif tool_config.tool_choice == ToolChoice.required: pass else: # verify tool_choice is one of the tools - tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] - if tool_config.tool_choice not in tool_names: - raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") + for t in converted_tools: + if t.type == ToolType.function.value: + if tool_config.tool_choice == t.name: + break + elif t.type in ( + ToolType.web_search.value, + ToolType.wolfram_alpha.value, + ToolType.code_interpreter.value, + ): + if tool_config.tool_choice == t.type: + break + else: + raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {converted_tools}") params = dict( model_id=model_id, messages=messages, sampling_params=sampling_params, - tools=tools, + tools=converted_tools, tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, response_format=response_format, diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py index f762eb50f..f1c280f32 100644 --- a/llama_stack/models/llama/datatypes.py +++ b/llama_stack/models/llama/datatypes.py @@ -33,10 +33,10 @@ class Role(Enum): tool = "tool" -class BuiltinTool(Enum): - brave_search = "brave_search" +class ToolType(Enum): + function = "function" + web_search = "web_search" wolfram_alpha = "wolfram_alpha" - photogen = "photogen" code_interpreter = "code_interpreter" @@ -45,8 +45,9 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]] class ToolCall(BaseModel): + type: ToolType call_id: str - tool_name: Union[BuiltinTool, str] + tool_name: str # Plan is to deprecate the Dict in favor of a JSON string # that is parsed on the client side instead of trying to manage # the recursive type here. @@ -59,12 +60,18 @@ class ToolCall(BaseModel): @field_validator("tool_name", mode="before") @classmethod def validate_field(cls, v): + # for backwards compatibility, we allow the tool name to be a string or a BuiltinTool + # TODO: remove ToolDefinitionDeprecated in v0.1.10 + tool_name = v if isinstance(v, str): try: - return BuiltinTool(v) + tool_name = BuiltinTool(v) except ValueError: - return v - return v + pass + + if isinstance(tool_name, BuiltinTool): + return tool_name.to_tool().type + return tool_name class ToolPromptFormat(Enum): @@ -151,8 +158,136 @@ class ToolParamDefinition(BaseModel): default: Optional[Any] = None +class Tool(BaseModel): + type: ToolType + + @classmethod + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + required_properties = ["name", "description", "parameters"] + for prop in required_properties: + has_property = any(isinstance(v, property) for v in [cls.__dict__.get(prop)]) + has_field = prop in cls.__annotations__ or prop in cls.__dict__ + if not has_property and not has_field: + raise TypeError(f"Class {cls.__name__} must implement '{prop}' property or field") + + @json_schema_type -class ToolDefinition(BaseModel): +class WebSearchTool(Tool): + type: Literal[ToolType.web_search.value] = ToolType.web_search.value + + @property + def name(self) -> str: + return "web_search" + + @property + def description(self) -> str: + return "Search the web for information" + + @property + def parameters(self) -> Dict[str, ToolParamDefinition]: + return { + "query": ToolParamDefinition( + description="The query to search for", + param_type="string", + required=True, + ), + } + + +@json_schema_type +class WolframAlphaTool(Tool): + type: Literal[ToolType.wolfram_alpha.value] = ToolType.wolfram_alpha.value + + @property + def name(self) -> str: + return "wolfram_alpha" + + @property + def description(self) -> str: + return "Query WolframAlpha for computational knowledge" + + @property + def parameters(self) -> Dict[str, ToolParamDefinition]: + return { + "query": ToolParamDefinition( + description="The query to compute", + param_type="string", + required=True, + ), + } + + +@json_schema_type +class CodeInterpreterTool(Tool): + type: Literal[ToolType.code_interpreter.value] = ToolType.code_interpreter.value + + @property + def name(self) -> str: + return "code_interpreter" + + @property + def description(self) -> str: + return "Execute code" + + @property + def parameters(self) -> Dict[str, ToolParamDefinition]: + return { + "code": ToolParamDefinition( + description="The code to execute", + param_type="string", + required=True, + ), + } + + +@json_schema_type +class FunctionTool(Tool): + type: Literal[ToolType.function.value] = ToolType.function.value + name: str + description: Optional[str] = None + parameters: Optional[Dict[str, ToolParamDefinition]] = None + + @field_validator("name", mode="before") + @classmethod + def validate_name(cls, v): + if v in ToolType.__members__: + raise ValueError(f"Tool name '{v}' is a tool type and cannot be used as a name of a function tool") + return v + + +ToolDefinition = Annotated[ + Union[WebSearchTool, WolframAlphaTool, CodeInterpreterTool, FunctionTool], Field(discriminator="type") +] + + +# TODO: remove ToolDefinitionDeprecated in v0.1.10 +class BuiltinTool(Enum): + brave_search = "brave_search" + wolfram_alpha = "wolfram_alpha" + code_interpreter = "code_interpreter" + + def to_tool_type(self) -> ToolType: + if self == BuiltinTool.brave_search: + return ToolType.web_search + elif self == BuiltinTool.wolfram_alpha: + return ToolType.wolfram_alpha + elif self == BuiltinTool.code_interpreter: + return ToolType.code_interpreter + + def to_tool(self) -> WebSearchTool | WolframAlphaTool | CodeInterpreterTool: + if self == BuiltinTool.brave_search: + return WebSearchTool() + elif self == BuiltinTool.wolfram_alpha: + return WolframAlphaTool() + elif self == BuiltinTool.code_interpreter: + return CodeInterpreterTool() + + +# TODO: remove ToolDefinitionDeprecated in v0.1.10 +@json_schema_type +class ToolDefinitionDeprecated(BaseModel): tool_name: Union[BuiltinTool, str] description: Optional[str] = None parameters: Optional[Dict[str, ToolParamDefinition]] = None @@ -167,6 +302,21 @@ class ToolDefinition(BaseModel): return v return v + def to_tool_definition(self) -> ToolDefinition: + # convert to ToolDefinition + if self.tool_name == BuiltinTool.brave_search: + return WebSearchTool() + elif self.tool_name == BuiltinTool.code_interpreter: + return CodeInterpreterTool() + elif self.tool_name == BuiltinTool.wolfram_alpha: + return WolframAlphaTool() + else: + return FunctionTool( + name=self.tool_name, + description=self.description, + parameters=self.parameters, + ) + @json_schema_type class GreedySamplingStrategy(BaseModel): diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py index 2862f8558..e1c44cff0 100644 --- a/llama_stack/models/llama/llama3/chat_format.py +++ b/llama_stack/models/llama/llama3/chat_format.py @@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Tuple from PIL import Image as PIL_Image from llama_stack.models.llama.datatypes import ( - BuiltinTool, RawContent, RawMediaItem, RawMessage, @@ -29,6 +28,7 @@ from llama_stack.models.llama.datatypes import ( StopReason, ToolCall, ToolPromptFormat, + ToolType, ) from .tokenizer import Tokenizer @@ -127,7 +127,7 @@ class ChatFormat: if ( message.role == "assistant" and len(message.tool_calls) > 0 - and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter + and message.tool_calls[0].type == ToolType.code_interpreter ): tokens.append(self.tokenizer.special_tokens["<|python_tag|>"]) @@ -194,6 +194,7 @@ class ChatFormat: stop_reason = StopReason.end_of_message tool_name = None + tool_type = ToolType.function tool_arguments = {} custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content) @@ -202,8 +203,8 @@ class ChatFormat: # Sometimes when agent has custom tools alongside builin tools # Agent responds for builtin tool calls in the format of the custom tools # This code tries to handle that case - if tool_name in BuiltinTool.__members__: - tool_name = BuiltinTool[tool_name] + if tool_name in ToolType.__members__: + tool_type = ToolType[tool_name] if isinstance(tool_arguments, dict): tool_arguments = { "query": list(tool_arguments.values())[0], @@ -215,10 +216,11 @@ class ChatFormat: tool_arguments = { "query": query, } - if tool_name in BuiltinTool.__members__: - tool_name = BuiltinTool[tool_name] + if tool_name in ToolType.__members__: + tool_type = ToolType[tool_name] elif ipython: - tool_name = BuiltinTool.code_interpreter + tool_name = ToolType.code_interpreter.value + tool_type = ToolType.code_interpreter tool_arguments = { "code": content, } @@ -228,6 +230,7 @@ class ChatFormat: call_id = str(uuid.uuid4()) tool_calls.append( ToolCall( + type=tool_type, call_id=call_id, tool_name=tool_name, arguments=tool_arguments, diff --git a/llama_stack/models/llama/llama3/interface.py b/llama_stack/models/llama/llama3/interface.py index 2579ab6c8..5abace9e2 100644 --- a/llama_stack/models/llama/llama3/interface.py +++ b/llama_stack/models/llama/llama3/interface.py @@ -17,7 +17,7 @@ from typing import List, Optional from termcolor import colored from llama_stack.models.llama.datatypes import ( - BuiltinTool, + FunctionTool, RawMessage, StopReason, ToolCall, @@ -25,7 +25,6 @@ from llama_stack.models.llama.datatypes import ( ToolPromptFormat, ) -from . import template_data from .chat_format import ChatFormat from .prompt_templates import ( BuiltinToolGenerator, @@ -150,8 +149,8 @@ class LLama31Interface: def system_messages( self, - builtin_tools: List[BuiltinTool], - custom_tools: List[ToolDefinition], + builtin_tools: List[ToolDefinition], + custom_tools: List[FunctionTool], instruction: Optional[str] = None, ) -> List[RawMessage]: messages = [] @@ -227,31 +226,3 @@ class LLama31Interface: on_col = on_colors[i % len(on_colors)] print(colored(self.tokenizer.decode([t]), "white", on_col), end="") print("\n", end="") - - -def list_jinja_templates() -> List[Template]: - return TEMPLATES - - -def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat): - by_name = {t.template_name: t for t in TEMPLATES} - if name not in by_name: - raise ValueError(f"No template found for `{name}`") - - template = by_name[name] - interface = LLama31Interface(tool_prompt_format) - - data_func = getattr(template_data, template.data_provider) - if template.role == "system": - messages = interface.system_messages(**data_func()) - elif template.role == "tool": - messages = interface.tool_response_messages(**data_func()) - elif template.role == "assistant": - messages = interface.assistant_response_messages(**data_func()) - elif template.role == "user": - messages = interface.user_message(**data_func()) - - tokens = interface.get_tokens(messages) - special_tokens = list(interface.tokenizer.special_tokens.values()) - tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens] - return template, tokens diff --git a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py index 9da6a640e..2cd69f5d0 100644 --- a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +++ b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py @@ -16,9 +16,13 @@ from datetime import datetime from typing import Any, List, Optional from llama_stack.models.llama.datatypes import ( - BuiltinTool, + CodeInterpreterTool, + FunctionTool, ToolDefinition, ToolParamDefinition, + ToolType, + WebSearchTool, + WolframAlphaTool, ) from .base import PromptTemplate, PromptTemplateGeneratorBase @@ -47,7 +51,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase): def _tool_breakdown(self, tools: List[ToolDefinition]): builtin_tools, custom_tools = [], [] for dfn in tools: - if isinstance(dfn.tool_name, BuiltinTool): + if dfn.type != ToolType.function.value: builtin_tools.append(dfn) else: custom_tools.append(dfn) @@ -70,7 +74,11 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase): return PromptTemplate( template_str.lstrip("\n"), { - "builtin_tools": [t.tool_name.value for t in builtin_tools], + "builtin_tools": [ + # brave_search is used in training data for web_search + t.type if t.type != ToolType.web_search.value else "brave_search" + for t in builtin_tools + ], "custom_tools": custom_tools, }, ) @@ -79,19 +87,19 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase): return [ # builtin tools [ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ToolDefinition(tool_name=BuiltinTool.wolfram_alpha), + CodeInterpreterTool(), + WebSearchTool(), + WolframAlphaTool(), ], # only code interpretor [ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), + CodeInterpreterTool(), ], ] class JsonCustomToolGenerator(PromptTemplateGeneratorBase): - def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + def gen(self, custom_tools: List[FunctionTool]) -> PromptTemplate: template_str = textwrap.dedent( """ Answer the user's question by making use of the following functions if needed. @@ -99,7 +107,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase): Here is a list of functions in JSON format: {% for t in custom_tools -%} {# manually setting up JSON because jinja sorts keys in unexpected ways -#} - {%- set tname = t.tool_name -%} + {%- set tname = t.name -%} {%- set tdesc = t.description -%} {%- set tparams = t.parameters -%} {%- set required_params = [] -%} @@ -140,8 +148,8 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase): def data_examples(self) -> List[List[ToolDefinition]]: return [ [ - ToolDefinition( - tool_name="trending_songs", + FunctionTool( + name="trending_songs", description="Returns the trending songs on a Music site", parameters={ "n": ToolParamDefinition( @@ -161,14 +169,14 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase): class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase): - def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + def gen(self, custom_tools: List[FunctionTool]) -> PromptTemplate: template_str = textwrap.dedent( """ You have access to the following functions: {% for t in custom_tools %} {#- manually setting up JSON because jinja sorts keys in unexpected ways -#} - {%- set tname = t.tool_name -%} + {%- set tname = t.name -%} {%- set tdesc = t.description -%} {%- set modified_params = t.parameters.copy() -%} {%- for key, value in modified_params.items() -%} @@ -202,8 +210,8 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase): def data_examples(self) -> List[List[ToolDefinition]]: return [ [ - ToolDefinition( - tool_name="trending_songs", + FunctionTool( + name="trending_songs", description="Returns the trending songs on a Music site", parameters={ "n": ToolParamDefinition( @@ -240,7 +248,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 {"function_description": self._gen_function_description(custom_tools)}, ) - def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + def _gen_function_description(self, custom_tools: List[FunctionTool]) -> PromptTemplate: template_str = textwrap.dedent( """ If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] @@ -252,7 +260,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 [ {% for t in tools -%} {# manually setting up JSON because jinja sorts keys in unexpected ways -#} - {%- set tname = t.tool_name -%} + {%- set tname = t.name -%} {%- set tdesc = t.description -%} {%- set tparams = t.parameters -%} {%- set required_params = [] -%} @@ -289,8 +297,8 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 def data_examples(self) -> List[List[ToolDefinition]]: return [ [ - ToolDefinition( - tool_name="get_weather", + FunctionTool( + name="get_weather", description="Get weather info for places", parameters={ "city": ToolParamDefinition( diff --git a/llama_stack/models/llama/llama3/tool_utils.py b/llama_stack/models/llama/llama3/tool_utils.py index 71018898c..cdbf8de1e 100644 --- a/llama_stack/models/llama/llama3/tool_utils.py +++ b/llama_stack/models/llama/llama3/tool_utils.py @@ -16,7 +16,7 @@ import re from typing import Optional, Tuple from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat +from llama_stack.models.llama.datatypes import RecursiveType, ToolCall, ToolPromptFormat, ToolType logger = get_logger(name=__name__, category="inference") @@ -24,6 +24,12 @@ BUILTIN_TOOL_PATTERN = r'\b(?P\w+)\.call\(query="(?P[^"]*)"\)' CUSTOM_TOOL_CALL_PATTERN = re.compile(r"[^}]+)>(?P{.*?})") +# The model is trained with brave_search for web_search, so we need to map it +TOOL_NAME_MAP = { + "brave_search": ToolType.web_search.value, +} + + def is_json(s): try: parsed = json.loads(s) @@ -111,11 +117,6 @@ def parse_python_list_for_function_calls(input_string): class ToolUtils: - @staticmethod - def is_builtin_tool_call(message_body: str) -> bool: - match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body) - return match is not None - @staticmethod def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]: # Find the first match in the text @@ -125,7 +126,7 @@ class ToolUtils: if match: tool_name = match.group("tool_name") query = match.group("query") - return tool_name, query + return TOOL_NAME_MAP.get(tool_name, tool_name), query else: return None @@ -143,7 +144,7 @@ class ToolUtils: tool_name = match.group("function_name") query = match.group("args") try: - return tool_name, json.loads(query.replace("'", '"')) + return TOOL_NAME_MAP.get(tool_name, tool_name), json.loads(query.replace("'", '"')) except Exception as e: print("Exception while parsing json query for custom tool call", query, e) return None @@ -152,30 +153,28 @@ class ToolUtils: if ("type" in response and response["type"] == "function") or ("name" in response): function_name = response["name"] args = response["parameters"] - return function_name, args + return TOOL_NAME_MAP.get(function_name, function_name), args else: return None elif is_valid_python_list(message_body): res = parse_python_list_for_function_calls(message_body) # FIXME: Enable multiple tool calls - return res[0] + function_name, args = res[0] + return TOOL_NAME_MAP.get(function_name, function_name), args else: return None @staticmethod def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str: - if t.tool_name == BuiltinTool.brave_search: + if t.type == ToolType.web_search: q = t.arguments["query"] return f'brave_search.call(query="{q}")' - elif t.tool_name == BuiltinTool.wolfram_alpha: + elif t.type == ToolType.wolfram_alpha: q = t.arguments["query"] return f'wolfram_alpha.call(query="{q}")' - elif t.tool_name == BuiltinTool.photogen: - q = t.arguments["query"] - return f'photogen.call(query="{q}")' - elif t.tool_name == BuiltinTool.code_interpreter: + elif t.type == ToolType.code_interpreter: return t.arguments["code"] - else: + elif t.type == ToolType.function: fname = t.tool_name if tool_prompt_format == ToolPromptFormat.json: @@ -208,3 +207,5 @@ class ToolUtils: return f"[{fname}({args_str})]" else: raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}") + else: + raise ValueError(f"Unsupported tool type: {t.type}") diff --git a/llama_stack/models/llama/llama3_1/prompts.py b/llama_stack/models/llama/llama3_1/prompts.py index 9f56bc23b..909aacbfb 100644 --- a/llama_stack/models/llama/llama3_1/prompts.py +++ b/llama_stack/models/llama/llama3_1/prompts.py @@ -15,11 +15,11 @@ import textwrap from typing import List from llama_stack.models.llama.datatypes import ( - BuiltinTool, RawMessage, StopReason, ToolCall, ToolPromptFormat, + ToolType, ) from ..prompt_format import ( @@ -184,8 +184,9 @@ def usecases() -> List[UseCase | str]: stop_reason=StopReason.end_of_message, tool_calls=[ ToolCall( + type=ToolType.wolfram_alpha, call_id="tool_call_id", - tool_name=BuiltinTool.wolfram_alpha, + tool_name=ToolType.wolfram_alpha.value, arguments={"query": "100th decimal of pi"}, ) ], diff --git a/llama_stack/models/llama/llama3_3/prompts.py b/llama_stack/models/llama/llama3_3/prompts.py index 194e4fa26..956f42d69 100644 --- a/llama_stack/models/llama/llama3_3/prompts.py +++ b/llama_stack/models/llama/llama3_3/prompts.py @@ -15,11 +15,11 @@ import textwrap from typing import List from llama_stack.models.llama.datatypes import ( - BuiltinTool, RawMessage, StopReason, ToolCall, ToolPromptFormat, + ToolType, ) from ..prompt_format import ( @@ -183,8 +183,9 @@ def usecases() -> List[UseCase | str]: stop_reason=StopReason.end_of_message, tool_calls=[ ToolCall( + type=ToolType.wolfram_alpha, call_id="tool_call_id", - tool_name=BuiltinTool.wolfram_alpha, + tool_name=ToolType.wolfram_alpha.value, arguments={"query": "100th decimal of pi"}, ) ], diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index fe1726b07..8916bae96 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -53,7 +53,7 @@ from llama_stack.apis.inference import ( SamplingParams, StopReason, SystemMessage, - ToolDefinition, + ToolDefinitionDeprecated, ToolResponse, ToolResponseMessage, UserMessage, @@ -771,7 +771,7 @@ class ChatAgent(ShieldRunnerMixin): for tool_def in self.agent_config.client_tools: if tool_name_to_def.get(tool_def.name, None): raise ValueError(f"Tool {tool_def.name} already exists") - tool_name_to_def[tool_def.name] = ToolDefinition( + tool_name_to_def[tool_def.name] = ToolDefinitionDeprecated( tool_name=tool_def.name, description=tool_def.description, parameters={ @@ -814,7 +814,7 @@ class ChatAgent(ShieldRunnerMixin): if tool_name_to_def.get(identifier, None): raise ValueError(f"Tool {identifier} already exists") if identifier: - tool_name_to_def[tool_def.identifier] = ToolDefinition( + tool_name_to_def[tool_def.identifier] = ToolDefinitionDeprecated( tool_name=identifier, description=tool_def.description, parameters={ @@ -854,30 +854,23 @@ class ChatAgent(ShieldRunnerMixin): tool_call: ToolCall, ) -> ToolInvocationResult: tool_name = tool_call.tool_name - registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs] + registered_tool_names = list(self.tool_name_to_args.keys()) if tool_name not in registered_tool_names: raise ValueError( f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}" ) - if isinstance(tool_name, BuiltinTool): - if tool_name == BuiltinTool.brave_search: - tool_name_str = WEB_SEARCH_TOOL - else: - tool_name_str = tool_name.value - else: - tool_name_str = tool_name - logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}") + logger.info(f"executing tool call: {tool_name} with args: {tool_call.arguments}") result = await self.tool_runtime_api.invoke_tool( - tool_name=tool_name_str, + tool_name=tool_name, kwargs={ "session_id": session_id, # get the arguments generated by the model and augment with toolgroup arg overrides for the agent **tool_call.arguments, - **self.tool_name_to_args.get(tool_name_str, {}), + **self.tool_name_to_args.get(tool_name, {}), }, ) - logger.debug(f"tool call {tool_name_str} completed with result: {result}") + logger.debug(f"tool call {tool_name} completed with result: {result}") return result async def handle_documents( diff --git a/llama_stack/providers/inline/inference/vllm/openai_utils.py b/llama_stack/providers/inline/inference/vllm/openai_utils.py index 90b5398f9..73a4fcbbe 100644 --- a/llama_stack/providers/inline/inference/vllm/openai_utils.py +++ b/llama_stack/providers/inline/inference/vllm/openai_utils.py @@ -16,7 +16,7 @@ from llama_stack.apis.inference import ( ToolChoice, UserMessage, ) -from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition +from llama_stack.models.llama.datatypes import ToolDefinition, ToolType from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, get_sampling_options, @@ -65,7 +65,7 @@ def _llama_stack_tools_to_openai_tools( result = [] for t in tools: - if isinstance(t.tool_name, BuiltinTool): + if t.type != ToolType.function.value: raise NotImplementedError("Built-in tools not yet implemented") if t.parameters is None: parameters = None diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index eda1a179c..acdc79ca6 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -45,7 +45,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model, ModelType -from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall +from llama_stack.models.llama.datatypes import StopReason, ToolCall from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( @@ -110,6 +110,8 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict] for tool in tools: properties = {} compat_required = [] + + tool_name = tool.name if tool.parameters: for tool_key, tool_param in tool.parameters.items(): properties[tool_key] = {"type": tool_param.param_type} @@ -120,12 +122,6 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict] if tool_param.required: compat_required.append(tool_key) - # The tool.tool_name can be a str or a BuiltinTool enum. If - # it's the latter, convert to a string. - tool_name = tool.tool_name - if isinstance(tool_name, BuiltinTool): - tool_name = tool_name.value - compat_tool = { "type": "function", "function": { diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index 78b47eb56..56f8f1146 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -17,7 +17,6 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.providers.datatypes import ToolsProtocolPrivate from .config import BraveSearchToolConfig @@ -61,7 +60,6 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest parameter_type="string", ) ], - built_in_type=BuiltinTool.brave_search, ) ] diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index b264c7312..46bed694b 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -80,12 +80,12 @@ from llama_stack.apis.inference import ( UserMessage, ) from llama_stack.models.llama.datatypes import ( - BuiltinTool, GreedySamplingStrategy, SamplingParams, StopReason, ToolCall, ToolDefinition, + ToolType, TopKSamplingStrategy, TopPSamplingStrategy, ) @@ -271,7 +271,7 @@ def process_chat_completion_response( else: # only return tool_calls if provided in the request new_tool_calls = [] - request_tools = {t.tool_name: t for t in request.tools} + request_tools = {t.name: t for t in request.tools} for t in raw_message.tool_calls: if t.tool_name in request_tools: new_tool_calls.append(t) @@ -423,7 +423,7 @@ async def process_chat_completion_stream_response( ) ) - request_tools = {t.tool_name: t for t in request.tools} + request_tools = {t.name: t for t in request.tools} for tool_call in message.tool_calls: if tool_call.tool_name in request_tools: yield ChatCompletionResponseStreamChunk( @@ -574,7 +574,7 @@ async def convert_message_to_openai_dict_new( OpenAIChatCompletionMessageToolCall( id=tool.call_id, function=OpenAIFunction( - name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), + name=tool.tool_name, arguments=json.dumps(tool.arguments), ), type="function", @@ -638,7 +638,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict: Convert a ToolDefinition to an OpenAI API-compatible dictionary. ToolDefinition: - tool_name: str | BuiltinTool + tool_name: str description: Optional[str] parameters: Optional[Dict[str, ToolParamDefinition]] @@ -677,10 +677,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict: } function = out["function"] - if isinstance(tool.tool_name, BuiltinTool): - function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient? - else: - function.update(name=tool.tool_name) + function.update(name=tool.name) if tool.description: function.update(description=tool.description) @@ -761,6 +758,7 @@ def _convert_openai_tool_calls( return [ ToolCall( + type=ToolType.function, call_id=call.id, tool_name=call.function.name, arguments=json.loads(call.function.arguments), @@ -975,6 +973,7 @@ async def convert_openai_chat_completion_stream( try: arguments = json.loads(buffer["arguments"]) tool_call = ToolCall( + type=ToolType.function, call_id=buffer["call_id"], tool_name=buffer["name"], arguments=arguments, diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 1edf445c0..938c49eef 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -43,6 +43,7 @@ from llama_stack.models.llama.datatypes import ( Role, StopReason, ToolPromptFormat, + ToolType, is_multimodal, ) from llama_stack.models.llama.llama3.chat_format import ChatFormat @@ -374,8 +375,8 @@ def augment_messages_for_tools_llama_3_1( messages.append(SystemMessage(content=sys_content)) - has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools) - if has_custom_tools: + custom_tools = [t for t in request.tools if t.type == ToolType.function.value] + if custom_tools: fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json if fmt == ToolPromptFormat.json: tool_gen = JsonCustomToolGenerator() @@ -384,7 +385,6 @@ def augment_messages_for_tools_llama_3_1( else: raise ValueError(f"Non supported ToolPromptFormat {fmt}") - custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)] custom_template = tool_gen.gen(custom_tools) messages.append(UserMessage(content=custom_template.render())) @@ -407,7 +407,7 @@ def augment_messages_for_tools_llama_3_2( sys_content = "" custom_tools, builtin_tools = [], [] for t in request.tools: - if isinstance(t.tool_name, str): + if t.type == ToolType.function.value: custom_tools.append(t) else: builtin_tools.append(t) @@ -419,7 +419,7 @@ def augment_messages_for_tools_llama_3_2( sys_content += tool_template.render() sys_content += "\n" - custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)] + custom_tools = [t for t in request.tools if t.type == ToolType.function.value] if custom_tools: fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list if fmt != ToolPromptFormat.python_list: diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py index 7011dc02d..3a2c57fd4 100644 --- a/tests/integration/agents/test_agents.py +++ b/tests/integration/agents/test_agents.py @@ -188,16 +188,12 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent } ], session_id=session_id, + stream=False, ) - logs = [str(log) for log in AgentEventLogger().log(response) if log is not None] - logs_str = "".join(logs) - - assert "tool_execution>" in logs_str - assert "Tool:brave_search Response:" in logs_str - assert "mark zuckerberg" in logs_str.lower() - if len(agent_config["output_shields"]) > 0: - assert "No Violation" in logs_str + tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution") + assert tool_execution_step.tool_calls[0].tool_name == "web_search" + assert "mark zuckerberg" in response.output_message.content.lower() def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config): diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py index c9649df60..62bbd536d 100644 --- a/tests/integration/inference/test_text_inference.py +++ b/tests/integration/inference/test_text_inference.py @@ -284,6 +284,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_ "test_case", [ "inference:chat_completion:tool_calling", + "inference:chat_completion:tool_calling_deprecated", ], ) def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case): @@ -300,7 +301,9 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_mo assert response.completion_message.role == "assistant" assert len(response.completion_message.tool_calls) == 1 - assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"] + assert response.completion_message.tool_calls[0].tool_name == ( + tc["tools"][0]["tool_name"] if "tool_name" in tc["tools"][0] else tc["tools"][0]["name"] + ) assert response.completion_message.tool_calls[0].arguments == tc["expected"] @@ -334,7 +337,7 @@ def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models stream=True, ) tool_invocation_content = extract_tool_invocation_content(response) - expected_tool_name = tc["tools"][0]["tool_name"] + expected_tool_name = tc["tools"][0]["tool_name"] if "tool_name" in tc["tools"][0] else tc["tools"][0]["name"] expected_argument = tc["expected"] assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]" @@ -358,7 +361,7 @@ def test_text_chat_completion_with_tool_choice_required(client_with_models, text stream=True, ) tool_invocation_content = extract_tool_invocation_content(response) - expected_tool_name = tc["tools"][0]["tool_name"] + expected_tool_name = tc["tools"][0]["tool_name"] if "tool_name" in tc["tools"][0] else tc["tools"][0]["name"] expected_argument = tc["expected"] assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]" @@ -432,14 +435,11 @@ def test_text_chat_completion_tool_calling_tools_not_in_request( ): tc = TestCase(test_case) - # TODO: more dynamic lookup on tool_prompt_format for model family - tool_prompt_format = "json" if "3.1" in text_model_id else "python_list" request = { "model_id": text_model_id, "messages": tc["messages"], "tools": tc["tools"], "tool_choice": "auto", - "tool_prompt_format": tool_prompt_format, "stream": streaming, } @@ -457,3 +457,30 @@ def test_text_chat_completion_tool_calling_tools_not_in_request( else: for tc in response.completion_message.tool_calls: assert tc.tool_name == "get_object_namespace_list" + + +@pytest.mark.parametrize( + "test_case", + [ + "inference:chat_completion:tool_calling_builtin_web_search", + "inference:chat_completion:tool_calling_builtin_brave_search", + "inference:chat_completion:tool_calling_builtin_code_interpreter", + "inference:chat_completion:tool_calling_builtin_code_interpreter_deprecated", + ], +) +def test_text_chat_completion_tool_calling_builtin(client_with_models, text_model_id, test_case): + tc = TestCase(test_case) + + request = { + "model_id": text_model_id, + "messages": tc["messages"], + "tools": tc["tools"], + "tool_choice": "auto", + "stream": False, + } + + response = client_with_models.inference.chat_completion(**request) + + for tool_call in response.completion_message.tool_calls: + print(tool_call) + assert tool_call.tool_name == tc["expected"] diff --git a/tests/integration/test_cases/inference/chat_completion.json b/tests/integration/test_cases/inference/chat_completion.json index e87c046b0..3b47b0816 100644 --- a/tests/integration/test_cases/inference/chat_completion.json +++ b/tests/integration/test_cases/inference/chat_completion.json @@ -49,7 +49,7 @@ "expected": "Washington" } }, - "tool_calling": { + "tool_calling_deprecated": { "data": { "messages": [ {"role": "system", "content": "Pretend you are a weather assistant."}, @@ -72,6 +72,30 @@ } } }, + "tool_calling": { + "data": { + "messages": [ + {"role": "system", "content": "Pretend you are a weather assistant."}, + {"role": "user", "content": "What's the weather like in San Francisco?"} + ], + "tools": [ + { + "type": "function", + "name": "get_weather", + "description": "Get the current weather", + "parameters": { + "location": { + "param_type": "string", + "description": "The city and state (both required), e.g. San Francisco, CA." + } + } + } + ], + "expected": { + "location": "San Francisco, CA" + } + } + }, "sample_messages_tool_calling": { "data": { "messages": [ @@ -128,6 +152,56 @@ } } }, + "tool_calling_builtin_web_search": { + "data": { + "messages": [ + {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."}, + {"role": "user", "content": "What's the weather like in San Francisco?"} + ], + "tools": [ + { + "type": "web_search" + } + ], + "expected": "web_search" + } + }, + "tool_calling_builtin_brave_search": { + "data": { + "messages": [ + {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."}, + {"role": "user", "content": "What's the weather like in San Francisco?"} + ], + "tools": [{ + "tool_name": "brave_search" + }], + "expected": "web_search" + } + }, + "tool_calling_builtin_code_interpreter_deprecated": { + "data": { + "messages": [ + {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."}, + {"role": "user", "content": "plot log(x) from -10 to 10"} + ], + "tools": [{ + "tool_name": "code_interpreter" + }], + "expected": "code_interpreter" + } + }, + "tool_calling_builtin_code_interpreter": { + "data": { + "messages": [ + {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."}, + {"role": "user", "content": "plot log(x) from -10 to 10"} + ], + "tools": [{ + "type": "code_interpreter" + }], + "expected": "code_interpreter" + } + }, "tool_calling_tools_absent": { "data": { "messages": [ @@ -146,6 +220,7 @@ "tool_calls": [ { "call_id": "1", + "type": "function", "tool_name": "get_object_namespace_list", "arguments": { "kind": "pod", diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py index 0e2780e50..b2c1ae657 100644 --- a/tests/unit/models/test_prompt_adapter.py +++ b/tests/unit/models/test_prompt_adapter.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import asyncio -import unittest + +import pytest from llama_stack.apis.inference import ( ChatCompletionRequest, @@ -17,10 +18,11 @@ from llama_stack.apis.inference import ( UserMessage, ) from llama_stack.models.llama.datatypes import ( - BuiltinTool, - ToolDefinition, + CodeInterpreterTool, + FunctionTool, ToolParamDefinition, ToolPromptFormat, + WebSearchTool, ) from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, @@ -31,258 +33,269 @@ MODEL = "Llama3.1-8B-Instruct" MODEL3_2 = "Llama3.2-3B-Instruct" -class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): - asyncio.get_running_loop().set_debug(False) +@pytest.fixture(autouse=True) +def setup_loop(): + loop = asyncio.get_event_loop() + loop.set_debug(False) + return loop - async def test_system_default(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 2) - self.assertEqual(messages[-1].content, content) - self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) - async def test_system_builtin_only(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 2) - self.assertEqual(messages[-1].content, content) - self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) - self.assertTrue("Tools: brave_search" in messages[0].content) +@pytest.mark.asyncio +async def test_system_default(): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 2 + assert messages[-1].content == content + assert "Cutting Knowledge Date: December 2023" in messages[0].content - async def test_system_custom_only(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ) - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json), - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 3) - self.assertTrue("Environment: ipython" in messages[0].content) - self.assertTrue("Return function calls in JSON format" in messages[1].content) - self.assertEqual(messages[-1].content, content) +@pytest.mark.asyncio +async def test_system_builtin_only(): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + tools=[ + CodeInterpreterTool(), + WebSearchTool(), + ], + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 2 + assert messages[-1].content == content + assert "Cutting Knowledge Date: December 2023" in messages[0].content + assert "Tools: brave_search" in messages[0].content - async def test_system_custom_and_builtin(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 3) - self.assertTrue("Environment: ipython" in messages[0].content) - self.assertTrue("Tools: brave_search" in messages[0].content) +@pytest.mark.asyncio +async def test_system_custom_only(): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + tools=[ + FunctionTool( + name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, + ) + ], + tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json), + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 3 + assert "Environment: ipython" in messages[0].content + assert "Return function calls in JSON format" in messages[1].content + assert messages[-1].content == content - self.assertTrue("Return function calls in JSON format" in messages[1].content) - self.assertEqual(messages[-1].content, content) - async def test_completion_message_encoding(self): - request = ChatCompletionRequest( - model=MODEL3_2, - messages=[ - UserMessage(content="hello"), - CompletionMessage( - content="", - stop_reason=StopReason.end_of_turn, - tool_calls=[ - ToolCall( - tool_name="custom1", - arguments={"param1": "value1"}, - call_id="123", - ) - ], - ), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ), - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list), - ) - prompt = await chat_completion_request_to_prompt(request, request.model) - self.assertIn('[custom1(param1="value1")]', prompt) - - request.model = MODEL - request.tool_config.tool_prompt_format = ToolPromptFormat.json - prompt = await chat_completion_request_to_prompt(request, request.model) - self.assertIn( - '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', - prompt, - ) - - async def test_user_provided_system_message(self): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 2, messages) - self.assertTrue(messages[0].content.endswith(system_prompt)) - - self.assertEqual(messages[-1].content, content) - - async def test_repalce_system_message_behavior_builtin_tools(self): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format="python_list", - system_message_behavior="replace", +@pytest.mark.asyncio +async def test_system_custom_and_builtin(): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + tools=[ + CodeInterpreterTool(), + WebSearchTool(), + FunctionTool( + name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - self.assertEqual(len(messages), 2, messages) - self.assertTrue(messages[0].content.endswith(system_prompt)) - self.assertIn("Environment: ipython", messages[0].content) - self.assertEqual(messages[-1].content, content) + ], + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 3 + assert "Environment: ipython" in messages[0].content + assert "Tools: brave_search" in messages[0].content + assert "Return function calls in JSON format" in messages[1].content + assert messages[-1].content == content - async def test_repalce_system_message_behavior_custom_tools(self): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format="python_list", - system_message_behavior="replace", + +@pytest.mark.asyncio +async def test_completion_message_encoding(): + request = ChatCompletionRequest( + model=MODEL3_2, + messages=[ + UserMessage(content="hello"), + CompletionMessage( + content="", + stop_reason=StopReason.end_of_turn, + tool_calls=[ + ToolCall( + type="function", + tool_name="custom1", + arguments={"param1": "value1"}, + call_id="123", + ) + ], ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - - self.assertEqual(len(messages), 2, messages) - self.assertTrue(messages[0].content.endswith(system_prompt)) - self.assertIn("Environment: ipython", messages[0].content) - self.assertEqual(messages[-1].content, content) - - async def test_replace_system_message_behavior_custom_tools_with_template(self): - content = "Hello !" - system_prompt = "You are a pirate {{ function_description }}" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format="python_list", - system_message_behavior="replace", + ], + tools=[ + FunctionTool( + name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) + ], + tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list), + ) + prompt = await chat_completion_request_to_prompt(request, request.model) + assert '[custom1(param1="value1")]' in prompt - self.assertEqual(len(messages), 2, messages) - self.assertIn("Environment: ipython", messages[0].content) - self.assertIn("You are a pirate", messages[0].content) - # function description is present in the system prompt - self.assertIn('"name": "custom1"', messages[0].content) - self.assertEqual(messages[-1].content, content) + request.model = MODEL + request.tool_config.tool_prompt_format = ToolPromptFormat.json + prompt = await chat_completion_request_to_prompt(request, request.model) + assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt + + +@pytest.mark.asyncio +async def test_user_provided_system_message(): + content = "Hello !" + system_prompt = "You are a pirate" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + CodeInterpreterTool(), + ], + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 2 + assert messages[0].content.endswith(system_prompt) + assert messages[-1].content == content + + +@pytest.mark.asyncio +async def test_repalce_system_message_behavior_builtin_tools(): + content = "Hello !" + system_prompt = "You are a pirate" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + CodeInterpreterTool(), + ], + tool_config=ToolConfig( + tool_choice="auto", + tool_prompt_format="python_list", + system_message_behavior="replace", + ), + ) + messages = chat_completion_request_to_messages(request, MODEL3_2) + assert len(messages) == 2 + assert messages[0].content.endswith(system_prompt) + assert "Environment: ipython" in messages[0].content + assert messages[-1].content == content + + +@pytest.mark.asyncio +async def test_repalce_system_message_behavior_custom_tools(): + content = "Hello !" + system_prompt = "You are a pirate" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + CodeInterpreterTool(), + FunctionTool( + name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, + ), + ], + tool_config=ToolConfig( + tool_choice="auto", + tool_prompt_format="python_list", + system_message_behavior="replace", + ), + ) + messages = chat_completion_request_to_messages(request, MODEL3_2) + assert len(messages) == 2 + assert messages[0].content.endswith(system_prompt) + assert "Environment: ipython" in messages[0].content + assert messages[-1].content == content + + +@pytest.mark.asyncio +async def test_replace_system_message_behavior_custom_tools_with_template(): + content = "Hello !" + system_prompt = "You are a pirate {{ function_description }}" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + CodeInterpreterTool(), + FunctionTool( + name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, + ), + ], + tool_config=ToolConfig( + tool_choice="auto", + tool_prompt_format="python_list", + system_message_behavior="replace", + ), + ) + messages = chat_completion_request_to_messages(request, MODEL3_2) + assert len(messages) == 2 + assert "Environment: ipython" in messages[0].content + assert "You are a pirate" in messages[0].content + assert '"name": "custom1"' in messages[0].content + assert messages[-1].content == content diff --git a/tests/unit/models/test_system_prompts.py b/tests/unit/models/test_system_prompts.py index 1f4ccc7e3..e068ce775 100644 --- a/tests/unit/models/test_system_prompts.py +++ b/tests/unit/models/test_system_prompts.py @@ -33,7 +33,7 @@ class PromptTemplateTests(unittest.TestCase): if not example: continue for tool in example: - assert tool.tool_name in text + assert tool.name in text def test_system_default(self): generator = SystemDefaultGenerator()