diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 8a46a89ad..2274f86e7 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -3798,6 +3798,21 @@ ], "title": "AppendRowsRequest" }, + "CodeInterpreterTool": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "code_interpreter", + "default": "code_interpreter" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "CodeInterpreterTool" + }, "CompletionMessage": { "type": "object", "properties": { @@ -3837,6 +3852,34 @@ "title": "CompletionMessage", "description": "A message containing the model's (assistant) response in a chat conversation." }, + "FunctionTool": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "function", + "default": "function" + }, + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "parameters": { + "type": "object", + "additionalProperties": { + "$ref": "#/components/schemas/ToolParamDefinition" + } + } + }, + "additionalProperties": false, + "required": [ + "type", + "name" + ], + "title": "FunctionTool" + }, "GrammarResponseFormat": { "type": "object", "properties": { @@ -4138,25 +4181,21 @@ "ToolCall": { "type": "object", "properties": { + "type": { + "type": "string", + "enum": [ + "function", + "web_search", + "wolfram_alpha", + "code_interpreter" + ], + "title": "ToolType" + }, "call_id": { "type": "string" }, "tool_name": { - "oneOf": [ - { - "type": "string", - "enum": [ - "brave_search", - "wolfram_alpha", - "photogen", - "code_interpreter" - ], - "title": "BuiltinTool" - }, - { - "type": "string" - } - ] + "type": "string" }, "arguments": { "oneOf": [ @@ -4237,48 +4276,13 @@ }, "additionalProperties": false, "required": [ + "type", "call_id", "tool_name", "arguments" ], "title": "ToolCall" }, - "ToolDefinition": { - "type": "object", - "properties": { - "tool_name": { - "oneOf": [ - { - "type": "string", - "enum": [ - "brave_search", - "wolfram_alpha", - "photogen", - "code_interpreter" - ], - "title": "BuiltinTool" - }, - { - "type": "string" - } - ] - }, - "description": { - "type": "string" - }, - "parameters": { - "type": "object", - "additionalProperties": { - "$ref": "#/components/schemas/ToolParamDefinition" - } - } - }, - "additionalProperties": false, - "required": [ - "tool_name" - ], - "title": "ToolDefinition" - }, "ToolParamDefinition": { "type": "object", "properties": { @@ -4428,6 +4432,36 @@ "title": "UserMessage", "description": "A message from the user in a chat conversation." 
}, + "WebSearchTool": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "web_search", + "default": "web_search" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "WebSearchTool" + }, + "WolframAlphaTool": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "wolfram_alpha", + "default": "wolfram_alpha" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "WolframAlphaTool" + }, "BatchChatCompletionRequest": { "type": "object", "properties": { @@ -4449,7 +4483,29 @@ "tools": { "type": "array", "items": { - "$ref": "#/components/schemas/ToolDefinition" + "oneOf": [ + { + "$ref": "#/components/schemas/WebSearchTool" + }, + { + "$ref": "#/components/schemas/WolframAlphaTool" + }, + { + "$ref": "#/components/schemas/CodeInterpreterTool" + }, + { + "$ref": "#/components/schemas/FunctionTool" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "web_search": "#/components/schemas/WebSearchTool", + "wolfram_alpha": "#/components/schemas/WolframAlphaTool", + "code_interpreter": "#/components/schemas/CodeInterpreterTool", + "function": "#/components/schemas/FunctionTool" + } + } } }, "tool_choice": { @@ -4734,6 +4790,41 @@ "title": "ToolConfig", "description": "Configuration for tool use." }, + "ToolDefinitionDeprecated": { + "type": "object", + "properties": { + "tool_name": { + "oneOf": [ + { + "type": "string", + "enum": [ + "brave_search", + "wolfram_alpha", + "code_interpreter" + ], + "title": "BuiltinTool" + }, + { + "type": "string" + } + ] + }, + "description": { + "type": "string" + }, + "parameters": { + "type": "object", + "additionalProperties": { + "$ref": "#/components/schemas/ToolParamDefinition" + } + } + }, + "additionalProperties": false, + "required": [ + "tool_name" + ], + "title": "ToolDefinitionDeprecated" + }, "ChatCompletionRequest": { "type": "object", "properties": { @@ -4753,10 +4844,42 @@ "description": "Parameters to control the sampling strategy" }, "tools": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolDefinition" - }, + "oneOf": [ + { + "type": "array", + "items": { + "oneOf": [ + { + "$ref": "#/components/schemas/WebSearchTool" + }, + { + "$ref": "#/components/schemas/WolframAlphaTool" + }, + { + "$ref": "#/components/schemas/CodeInterpreterTool" + }, + { + "$ref": "#/components/schemas/FunctionTool" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "web_search": "#/components/schemas/WebSearchTool", + "wolfram_alpha": "#/components/schemas/WolframAlphaTool", + "code_interpreter": "#/components/schemas/CodeInterpreterTool", + "function": "#/components/schemas/FunctionTool" + } + } + } + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolDefinitionDeprecated" + } + } + ], "description": "(Optional) List of tool definitions available to the model" }, "tool_choice": { @@ -5630,21 +5753,7 @@ "type": "string" }, "tool_name": { - "oneOf": [ - { - "type": "string", - "enum": [ - "brave_search", - "wolfram_alpha", - "photogen", - "code_interpreter" - ], - "title": "BuiltinTool" - }, - { - "type": "string" - } - ] + "type": "string" }, "content": { "$ref": "#/components/schemas/InterleavedContent" diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 0b8f90490..d1517bd1e 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2607,6 +2607,17 @@ components: required: - rows title: 
AppendRowsRequest + CodeInterpreterTool: + type: object + properties: + type: + type: string + const: code_interpreter + default: code_interpreter + additionalProperties: false + required: + - type + title: CodeInterpreterTool CompletionMessage: type: object properties: @@ -2646,6 +2657,26 @@ components: title: CompletionMessage description: >- A message containing the model's (assistant) response in a chat conversation. + FunctionTool: + type: object + properties: + type: + type: string + const: function + default: function + name: + type: string + description: + type: string + parameters: + type: object + additionalProperties: + $ref: '#/components/schemas/ToolParamDefinition' + additionalProperties: false + required: + - type + - name + title: FunctionTool GrammarResponseFormat: type: object properties: @@ -2851,18 +2882,18 @@ components: ToolCall: type: object properties: + type: + type: string + enum: + - function + - web_search + - wolfram_alpha + - code_interpreter + title: ToolType call_id: type: string tool_name: - oneOf: - - type: string - enum: - - brave_search - - wolfram_alpha - - photogen - - code_interpreter - title: BuiltinTool - - type: string + type: string arguments: oneOf: - type: string @@ -2894,33 +2925,11 @@ components: type: string additionalProperties: false required: + - type - call_id - tool_name - arguments title: ToolCall - ToolDefinition: - type: object - properties: - tool_name: - oneOf: - - type: string - enum: - - brave_search - - wolfram_alpha - - photogen - - code_interpreter - title: BuiltinTool - - type: string - description: - type: string - parameters: - type: object - additionalProperties: - $ref: '#/components/schemas/ToolParamDefinition' - additionalProperties: false - required: - - tool_name - title: ToolDefinition ToolParamDefinition: type: object properties: @@ -3031,6 +3040,28 @@ components: title: UserMessage description: >- A message from the user in a chat conversation. + WebSearchTool: + type: object + properties: + type: + type: string + const: web_search + default: web_search + additionalProperties: false + required: + - type + title: WebSearchTool + WolframAlphaTool: + type: object + properties: + type: + type: string + const: wolfram_alpha + default: wolfram_alpha + additionalProperties: false + required: + - type + title: WolframAlphaTool BatchChatCompletionRequest: type: object properties: @@ -3047,7 +3078,18 @@ components: tools: type: array items: - $ref: '#/components/schemas/ToolDefinition' + oneOf: + - $ref: '#/components/schemas/WebSearchTool' + - $ref: '#/components/schemas/WolframAlphaTool' + - $ref: '#/components/schemas/CodeInterpreterTool' + - $ref: '#/components/schemas/FunctionTool' + discriminator: + propertyName: type + mapping: + web_search: '#/components/schemas/WebSearchTool' + wolfram_alpha: '#/components/schemas/WolframAlphaTool' + code_interpreter: '#/components/schemas/CodeInterpreterTool' + function: '#/components/schemas/FunctionTool' tool_choice: type: string enum: @@ -3272,6 +3314,28 @@ components: additionalProperties: false title: ToolConfig description: Configuration for tool use. 
+ ToolDefinitionDeprecated: + type: object + properties: + tool_name: + oneOf: + - type: string + enum: + - brave_search + - wolfram_alpha + - code_interpreter + title: BuiltinTool + - type: string + description: + type: string + parameters: + type: object + additionalProperties: + $ref: '#/components/schemas/ToolParamDefinition' + additionalProperties: false + required: + - tool_name + title: ToolDefinitionDeprecated ChatCompletionRequest: type: object properties: @@ -3290,9 +3354,24 @@ components: description: >- Parameters to control the sampling strategy tools: - type: array - items: - $ref: '#/components/schemas/ToolDefinition' + oneOf: + - type: array + items: + oneOf: + - $ref: '#/components/schemas/WebSearchTool' + - $ref: '#/components/schemas/WolframAlphaTool' + - $ref: '#/components/schemas/CodeInterpreterTool' + - $ref: '#/components/schemas/FunctionTool' + discriminator: + propertyName: type + mapping: + web_search: '#/components/schemas/WebSearchTool' + wolfram_alpha: '#/components/schemas/WolframAlphaTool' + code_interpreter: '#/components/schemas/CodeInterpreterTool' + function: '#/components/schemas/FunctionTool' + - type: array + items: + $ref: '#/components/schemas/ToolDefinitionDeprecated' description: >- (Optional) List of tool definitions available to the model tool_choice: @@ -3939,15 +4018,7 @@ components: call_id: type: string tool_name: - oneOf: - - type: string - enum: - - brave_search - - wolfram_alpha - - photogen - - code_interpreter - title: BuiltinTool - - type: string + type: string content: $ref: '#/components/schemas/InterleavedContent' metadata: diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 7d3539dcb..c51f4b971 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -17,18 +17,18 @@ from typing import ( runtime_checkable, ) -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem from llama_stack.apis.models import Model from llama_stack.apis.telemetry.telemetry import MetricResponseMixin from llama_stack.models.llama.datatypes import ( - BuiltinTool, SamplingParams, StopReason, ToolCall, ToolDefinition, + ToolDefinitionDeprecated, ToolPromptFormat, ) from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -156,23 +156,14 @@ Message = Annotated[ register_schema(Message, name="Message") +# TODO: move this to agent.py where this is used @json_schema_type class ToolResponse(BaseModel): call_id: str - tool_name: Union[BuiltinTool, str] + tool_name: str content: InterleavedContent metadata: Optional[Dict[str, Any]] = None - @field_validator("tool_name", mode="before") - @classmethod - def validate_field(cls, v): - if isinstance(v, str): - try: - return BuiltinTool(v) - except ValueError: - return v - return v - class ToolChoice(Enum): """Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model. 
@@ -462,7 +453,8 @@ class Inference(Protocol): model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = None, - tools: Optional[List[ToolDefinition]] = None, + # TODO: remove ToolDefinitionDeprecated in v0.1.10 + tools: Optional[List[ToolDefinition] | List[ToolDefinitionDeprecated]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = None, response_format: Optional[ResponseFormat] = None, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 6ff36a65c..764fa2406 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -32,6 +32,7 @@ from llama_stack.apis.inference import ( ToolChoice, ToolConfig, ToolDefinition, + ToolDefinitionDeprecated, ToolPromptFormat, ) from llama_stack.apis.models import Model, ModelType @@ -54,6 +55,9 @@ from llama_stack.apis.tools import ( ) from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.log import get_logger +from llama_stack.models.llama.datatypes import ( + ToolType, +) from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import RoutingTable @@ -229,7 +233,7 @@ class InferenceRouter(Inference): messages: List[Message], sampling_params: Optional[SamplingParams] = None, response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, + tools: Optional[List[ToolDefinition] | List[ToolDefinitionDeprecated]] = None, tool_choice: Optional[ToolChoice] = None, tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, @@ -259,24 +263,42 @@ class InferenceRouter(Inference): params["tool_prompt_format"] = tool_prompt_format tool_config = ToolConfig(**params) - tools = tools or [] + # TODO: remove ToolDefinitionDeprecated in v0.1.10 + converted_tools = [] + for tool in tools or []: + if isinstance(tool, ToolDefinitionDeprecated): + logger.warning(f"ToolDefinitionDeprecated: {tool}, use ToolDefinition instead") + converted_tools.append(tool.to_tool_definition()) + else: + converted_tools.append(tool) + if tool_config.tool_choice == ToolChoice.none: - tools = [] + converted_tools = [] elif tool_config.tool_choice == ToolChoice.auto: pass elif tool_config.tool_choice == ToolChoice.required: pass else: # verify tool_choice is one of the tools - tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] - if tool_config.tool_choice not in tool_names: - raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") + for t in converted_tools: + if t.type == ToolType.function.value: + if tool_config.tool_choice == t.name: + break + elif t.type in ( + ToolType.web_search.value, + ToolType.wolfram_alpha.value, + ToolType.code_interpreter.value, + ): + if tool_config.tool_choice == t.type: + break + else: + raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {converted_tools}") params = dict( model_id=model_id, messages=messages, sampling_params=sampling_params, - tools=tools, + tools=converted_tools, tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, response_format=response_format, diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py index f762eb50f..f1c280f32 100644 --- a/llama_stack/models/llama/datatypes.py +++ 
b/llama_stack/models/llama/datatypes.py @@ -33,10 +33,10 @@ class Role(Enum): tool = "tool" -class BuiltinTool(Enum): - brave_search = "brave_search" +class ToolType(Enum): + function = "function" + web_search = "web_search" wolfram_alpha = "wolfram_alpha" - photogen = "photogen" code_interpreter = "code_interpreter" @@ -45,8 +45,9 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]] class ToolCall(BaseModel): + type: ToolType call_id: str - tool_name: Union[BuiltinTool, str] + tool_name: str # Plan is to deprecate the Dict in favor of a JSON string # that is parsed on the client side instead of trying to manage # the recursive type here. @@ -59,12 +60,18 @@ class ToolCall(BaseModel): @field_validator("tool_name", mode="before") @classmethod def validate_field(cls, v): + # for backwards compatibility, we allow the tool name to be a string or a BuiltinTool + # TODO: remove ToolDefinitionDeprecated in v0.1.10 + tool_name = v if isinstance(v, str): try: - return BuiltinTool(v) + tool_name = BuiltinTool(v) except ValueError: - return v - return v + pass + + if isinstance(tool_name, BuiltinTool): + return tool_name.to_tool().type + return tool_name class ToolPromptFormat(Enum): @@ -151,8 +158,136 @@ class ToolParamDefinition(BaseModel): default: Optional[Any] = None +class Tool(BaseModel): + type: ToolType + + @classmethod + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + required_properties = ["name", "description", "parameters"] + for prop in required_properties: + has_property = any(isinstance(v, property) for v in [cls.__dict__.get(prop)]) + has_field = prop in cls.__annotations__ or prop in cls.__dict__ + if not has_property and not has_field: + raise TypeError(f"Class {cls.__name__} must implement '{prop}' property or field") + + @json_schema_type -class ToolDefinition(BaseModel): +class WebSearchTool(Tool): + type: Literal[ToolType.web_search.value] = ToolType.web_search.value + + @property + def name(self) -> str: + return "web_search" + + @property + def description(self) -> str: + return "Search the web for information" + + @property + def parameters(self) -> Dict[str, ToolParamDefinition]: + return { + "query": ToolParamDefinition( + description="The query to search for", + param_type="string", + required=True, + ), + } + + +@json_schema_type +class WolframAlphaTool(Tool): + type: Literal[ToolType.wolfram_alpha.value] = ToolType.wolfram_alpha.value + + @property + def name(self) -> str: + return "wolfram_alpha" + + @property + def description(self) -> str: + return "Query WolframAlpha for computational knowledge" + + @property + def parameters(self) -> Dict[str, ToolParamDefinition]: + return { + "query": ToolParamDefinition( + description="The query to compute", + param_type="string", + required=True, + ), + } + + +@json_schema_type +class CodeInterpreterTool(Tool): + type: Literal[ToolType.code_interpreter.value] = ToolType.code_interpreter.value + + @property + def name(self) -> str: + return "code_interpreter" + + @property + def description(self) -> str: + return "Execute code" + + @property + def parameters(self) -> Dict[str, ToolParamDefinition]: + return { + "code": ToolParamDefinition( + description="The code to execute", + param_type="string", + required=True, + ), + } + + +@json_schema_type +class FunctionTool(Tool): + type: Literal[ToolType.function.value] = ToolType.function.value + name: str + description: Optional[str] = None + parameters: Optional[Dict[str, ToolParamDefinition]] = None + + @field_validator("name", 
mode="before") + @classmethod + def validate_name(cls, v): + if v in ToolType.__members__: + raise ValueError(f"Tool name '{v}' is a tool type and cannot be used as a name of a function tool") + return v + + +ToolDefinition = Annotated[ + Union[WebSearchTool, WolframAlphaTool, CodeInterpreterTool, FunctionTool], Field(discriminator="type") +] + + +# TODO: remove ToolDefinitionDeprecated in v0.1.10 +class BuiltinTool(Enum): + brave_search = "brave_search" + wolfram_alpha = "wolfram_alpha" + code_interpreter = "code_interpreter" + + def to_tool_type(self) -> ToolType: + if self == BuiltinTool.brave_search: + return ToolType.web_search + elif self == BuiltinTool.wolfram_alpha: + return ToolType.wolfram_alpha + elif self == BuiltinTool.code_interpreter: + return ToolType.code_interpreter + + def to_tool(self) -> WebSearchTool | WolframAlphaTool | CodeInterpreterTool: + if self == BuiltinTool.brave_search: + return WebSearchTool() + elif self == BuiltinTool.wolfram_alpha: + return WolframAlphaTool() + elif self == BuiltinTool.code_interpreter: + return CodeInterpreterTool() + + +# TODO: remove ToolDefinitionDeprecated in v0.1.10 +@json_schema_type +class ToolDefinitionDeprecated(BaseModel): tool_name: Union[BuiltinTool, str] description: Optional[str] = None parameters: Optional[Dict[str, ToolParamDefinition]] = None @@ -167,6 +302,21 @@ class ToolDefinition(BaseModel): return v return v + def to_tool_definition(self) -> ToolDefinition: + # convert to ToolDefinition + if self.tool_name == BuiltinTool.brave_search: + return WebSearchTool() + elif self.tool_name == BuiltinTool.code_interpreter: + return CodeInterpreterTool() + elif self.tool_name == BuiltinTool.wolfram_alpha: + return WolframAlphaTool() + else: + return FunctionTool( + name=self.tool_name, + description=self.description, + parameters=self.parameters, + ) + @json_schema_type class GreedySamplingStrategy(BaseModel): diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py index 2862f8558..e1c44cff0 100644 --- a/llama_stack/models/llama/llama3/chat_format.py +++ b/llama_stack/models/llama/llama3/chat_format.py @@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Tuple from PIL import Image as PIL_Image from llama_stack.models.llama.datatypes import ( - BuiltinTool, RawContent, RawMediaItem, RawMessage, @@ -29,6 +28,7 @@ from llama_stack.models.llama.datatypes import ( StopReason, ToolCall, ToolPromptFormat, + ToolType, ) from .tokenizer import Tokenizer @@ -127,7 +127,7 @@ class ChatFormat: if ( message.role == "assistant" and len(message.tool_calls) > 0 - and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter + and message.tool_calls[0].type == ToolType.code_interpreter ): tokens.append(self.tokenizer.special_tokens["<|python_tag|>"]) @@ -194,6 +194,7 @@ class ChatFormat: stop_reason = StopReason.end_of_message tool_name = None + tool_type = ToolType.function tool_arguments = {} custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content) @@ -202,8 +203,8 @@ class ChatFormat: # Sometimes when agent has custom tools alongside builin tools # Agent responds for builtin tool calls in the format of the custom tools # This code tries to handle that case - if tool_name in BuiltinTool.__members__: - tool_name = BuiltinTool[tool_name] + if tool_name in ToolType.__members__: + tool_type = ToolType[tool_name] if isinstance(tool_arguments, dict): tool_arguments = { "query": list(tool_arguments.values())[0], @@ -215,10 +216,11 @@ class ChatFormat: tool_arguments = { 
"query": query, } - if tool_name in BuiltinTool.__members__: - tool_name = BuiltinTool[tool_name] + if tool_name in ToolType.__members__: + tool_type = ToolType[tool_name] elif ipython: - tool_name = BuiltinTool.code_interpreter + tool_name = ToolType.code_interpreter.value + tool_type = ToolType.code_interpreter tool_arguments = { "code": content, } @@ -228,6 +230,7 @@ class ChatFormat: call_id = str(uuid.uuid4()) tool_calls.append( ToolCall( + type=tool_type, call_id=call_id, tool_name=tool_name, arguments=tool_arguments, diff --git a/llama_stack/models/llama/llama3/interface.py b/llama_stack/models/llama/llama3/interface.py index 2579ab6c8..5abace9e2 100644 --- a/llama_stack/models/llama/llama3/interface.py +++ b/llama_stack/models/llama/llama3/interface.py @@ -17,7 +17,7 @@ from typing import List, Optional from termcolor import colored from llama_stack.models.llama.datatypes import ( - BuiltinTool, + FunctionTool, RawMessage, StopReason, ToolCall, @@ -25,7 +25,6 @@ from llama_stack.models.llama.datatypes import ( ToolPromptFormat, ) -from . import template_data from .chat_format import ChatFormat from .prompt_templates import ( BuiltinToolGenerator, @@ -150,8 +149,8 @@ class LLama31Interface: def system_messages( self, - builtin_tools: List[BuiltinTool], - custom_tools: List[ToolDefinition], + builtin_tools: List[ToolDefinition], + custom_tools: List[FunctionTool], instruction: Optional[str] = None, ) -> List[RawMessage]: messages = [] @@ -227,31 +226,3 @@ class LLama31Interface: on_col = on_colors[i % len(on_colors)] print(colored(self.tokenizer.decode([t]), "white", on_col), end="") print("\n", end="") - - -def list_jinja_templates() -> List[Template]: - return TEMPLATES - - -def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat): - by_name = {t.template_name: t for t in TEMPLATES} - if name not in by_name: - raise ValueError(f"No template found for `{name}`") - - template = by_name[name] - interface = LLama31Interface(tool_prompt_format) - - data_func = getattr(template_data, template.data_provider) - if template.role == "system": - messages = interface.system_messages(**data_func()) - elif template.role == "tool": - messages = interface.tool_response_messages(**data_func()) - elif template.role == "assistant": - messages = interface.assistant_response_messages(**data_func()) - elif template.role == "user": - messages = interface.user_message(**data_func()) - - tokens = interface.get_tokens(messages) - special_tokens = list(interface.tokenizer.special_tokens.values()) - tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens] - return template, tokens diff --git a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py index 9da6a640e..2cd69f5d0 100644 --- a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +++ b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py @@ -16,9 +16,13 @@ from datetime import datetime from typing import Any, List, Optional from llama_stack.models.llama.datatypes import ( - BuiltinTool, + CodeInterpreterTool, + FunctionTool, ToolDefinition, ToolParamDefinition, + ToolType, + WebSearchTool, + WolframAlphaTool, ) from .base import PromptTemplate, PromptTemplateGeneratorBase @@ -47,7 +51,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase): def _tool_breakdown(self, tools: List[ToolDefinition]): builtin_tools, custom_tools = [], [] for dfn in tools: - if isinstance(dfn.tool_name, BuiltinTool): + 
if dfn.type != ToolType.function.value: builtin_tools.append(dfn) else: custom_tools.append(dfn) @@ -70,7 +74,11 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase): return PromptTemplate( template_str.lstrip("\n"), { - "builtin_tools": [t.tool_name.value for t in builtin_tools], + "builtin_tools": [ + # brave_search is used in training data for web_search + t.type if t.type != ToolType.web_search.value else "brave_search" + for t in builtin_tools + ], "custom_tools": custom_tools, }, ) @@ -79,19 +87,19 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase): return [ # builtin tools [ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ToolDefinition(tool_name=BuiltinTool.wolfram_alpha), + CodeInterpreterTool(), + WebSearchTool(), + WolframAlphaTool(), ], # only code interpretor [ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), + CodeInterpreterTool(), ], ] class JsonCustomToolGenerator(PromptTemplateGeneratorBase): - def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + def gen(self, custom_tools: List[FunctionTool]) -> PromptTemplate: template_str = textwrap.dedent( """ Answer the user's question by making use of the following functions if needed. @@ -99,7 +107,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase): Here is a list of functions in JSON format: {% for t in custom_tools -%} {# manually setting up JSON because jinja sorts keys in unexpected ways -#} - {%- set tname = t.tool_name -%} + {%- set tname = t.name -%} {%- set tdesc = t.description -%} {%- set tparams = t.parameters -%} {%- set required_params = [] -%} @@ -140,8 +148,8 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase): def data_examples(self) -> List[List[ToolDefinition]]: return [ [ - ToolDefinition( - tool_name="trending_songs", + FunctionTool( + name="trending_songs", description="Returns the trending songs on a Music site", parameters={ "n": ToolParamDefinition( @@ -161,14 +169,14 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase): class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase): - def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + def gen(self, custom_tools: List[FunctionTool]) -> PromptTemplate: template_str = textwrap.dedent( """ You have access to the following functions: {% for t in custom_tools %} {#- manually setting up JSON because jinja sorts keys in unexpected ways -#} - {%- set tname = t.tool_name -%} + {%- set tname = t.name -%} {%- set tdesc = t.description -%} {%- set modified_params = t.parameters.copy() -%} {%- for key, value in modified_params.items() -%} @@ -202,8 +210,8 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase): def data_examples(self) -> List[List[ToolDefinition]]: return [ [ - ToolDefinition( - tool_name="trending_songs", + FunctionTool( + name="trending_songs", description="Returns the trending songs on a Music site", parameters={ "n": ToolParamDefinition( @@ -240,7 +248,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 {"function_description": self._gen_function_description(custom_tools)}, ) - def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + def _gen_function_description(self, custom_tools: List[FunctionTool]) -> PromptTemplate: template_str = textwrap.dedent( """ If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), 
func_name2(params)] @@ -252,7 +260,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 [ {% for t in tools -%} {# manually setting up JSON because jinja sorts keys in unexpected ways -#} - {%- set tname = t.tool_name -%} + {%- set tname = t.name -%} {%- set tdesc = t.description -%} {%- set tparams = t.parameters -%} {%- set required_params = [] -%} @@ -289,8 +297,8 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 def data_examples(self) -> List[List[ToolDefinition]]: return [ [ - ToolDefinition( - tool_name="get_weather", + FunctionTool( + name="get_weather", description="Get weather info for places", parameters={ "city": ToolParamDefinition( diff --git a/llama_stack/models/llama/llama3/tool_utils.py b/llama_stack/models/llama/llama3/tool_utils.py index 71018898c..cdbf8de1e 100644 --- a/llama_stack/models/llama/llama3/tool_utils.py +++ b/llama_stack/models/llama/llama3/tool_utils.py @@ -16,7 +16,7 @@ import re from typing import Optional, Tuple from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat +from llama_stack.models.llama.datatypes import RecursiveType, ToolCall, ToolPromptFormat, ToolType logger = get_logger(name=__name__, category="inference") @@ -24,6 +24,12 @@ BUILTIN_TOOL_PATTERN = r'\b(?P\w+)\.call\(query="(?P[^"]*)"\)' CUSTOM_TOOL_CALL_PATTERN = re.compile(r"[^}]+)>(?P{.*?})") +# The model is trained with brave_search for web_search, so we need to map it +TOOL_NAME_MAP = { + "brave_search": ToolType.web_search.value, +} + + def is_json(s): try: parsed = json.loads(s) @@ -111,11 +117,6 @@ def parse_python_list_for_function_calls(input_string): class ToolUtils: - @staticmethod - def is_builtin_tool_call(message_body: str) -> bool: - match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body) - return match is not None - @staticmethod def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]: # Find the first match in the text @@ -125,7 +126,7 @@ class ToolUtils: if match: tool_name = match.group("tool_name") query = match.group("query") - return tool_name, query + return TOOL_NAME_MAP.get(tool_name, tool_name), query else: return None @@ -143,7 +144,7 @@ class ToolUtils: tool_name = match.group("function_name") query = match.group("args") try: - return tool_name, json.loads(query.replace("'", '"')) + return TOOL_NAME_MAP.get(tool_name, tool_name), json.loads(query.replace("'", '"')) except Exception as e: print("Exception while parsing json query for custom tool call", query, e) return None @@ -152,30 +153,28 @@ class ToolUtils: if ("type" in response and response["type"] == "function") or ("name" in response): function_name = response["name"] args = response["parameters"] - return function_name, args + return TOOL_NAME_MAP.get(function_name, function_name), args else: return None elif is_valid_python_list(message_body): res = parse_python_list_for_function_calls(message_body) # FIXME: Enable multiple tool calls - return res[0] + function_name, args = res[0] + return TOOL_NAME_MAP.get(function_name, function_name), args else: return None @staticmethod def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str: - if t.tool_name == BuiltinTool.brave_search: + if t.type == ToolType.web_search: q = t.arguments["query"] return f'brave_search.call(query="{q}")' - elif t.tool_name == BuiltinTool.wolfram_alpha: + elif t.type == ToolType.wolfram_alpha: q = t.arguments["query"] return 
f'wolfram_alpha.call(query="{q}")' - elif t.tool_name == BuiltinTool.photogen: - q = t.arguments["query"] - return f'photogen.call(query="{q}")' - elif t.tool_name == BuiltinTool.code_interpreter: + elif t.type == ToolType.code_interpreter: return t.arguments["code"] - else: + elif t.type == ToolType.function: fname = t.tool_name if tool_prompt_format == ToolPromptFormat.json: @@ -208,3 +207,5 @@ class ToolUtils: return f"[{fname}({args_str})]" else: raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}") + else: + raise ValueError(f"Unsupported tool type: {t.type}") diff --git a/llama_stack/models/llama/llama3_1/prompts.py b/llama_stack/models/llama/llama3_1/prompts.py index 9f56bc23b..909aacbfb 100644 --- a/llama_stack/models/llama/llama3_1/prompts.py +++ b/llama_stack/models/llama/llama3_1/prompts.py @@ -15,11 +15,11 @@ import textwrap from typing import List from llama_stack.models.llama.datatypes import ( - BuiltinTool, RawMessage, StopReason, ToolCall, ToolPromptFormat, + ToolType, ) from ..prompt_format import ( @@ -184,8 +184,9 @@ def usecases() -> List[UseCase | str]: stop_reason=StopReason.end_of_message, tool_calls=[ ToolCall( + type=ToolType.wolfram_alpha, call_id="tool_call_id", - tool_name=BuiltinTool.wolfram_alpha, + tool_name=ToolType.wolfram_alpha.value, arguments={"query": "100th decimal of pi"}, ) ], diff --git a/llama_stack/models/llama/llama3_3/prompts.py b/llama_stack/models/llama/llama3_3/prompts.py index 194e4fa26..956f42d69 100644 --- a/llama_stack/models/llama/llama3_3/prompts.py +++ b/llama_stack/models/llama/llama3_3/prompts.py @@ -15,11 +15,11 @@ import textwrap from typing import List from llama_stack.models.llama.datatypes import ( - BuiltinTool, RawMessage, StopReason, ToolCall, ToolPromptFormat, + ToolType, ) from ..prompt_format import ( @@ -183,8 +183,9 @@ def usecases() -> List[UseCase | str]: stop_reason=StopReason.end_of_message, tool_calls=[ ToolCall( + type=ToolType.wolfram_alpha, call_id="tool_call_id", - tool_name=BuiltinTool.wolfram_alpha, + tool_name=ToolType.wolfram_alpha.value, arguments={"query": "100th decimal of pi"}, ) ], diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index fe1726b07..8916bae96 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -53,7 +53,7 @@ from llama_stack.apis.inference import ( SamplingParams, StopReason, SystemMessage, - ToolDefinition, + ToolDefinitionDeprecated, ToolResponse, ToolResponseMessage, UserMessage, @@ -771,7 +771,7 @@ class ChatAgent(ShieldRunnerMixin): for tool_def in self.agent_config.client_tools: if tool_name_to_def.get(tool_def.name, None): raise ValueError(f"Tool {tool_def.name} already exists") - tool_name_to_def[tool_def.name] = ToolDefinition( + tool_name_to_def[tool_def.name] = ToolDefinitionDeprecated( tool_name=tool_def.name, description=tool_def.description, parameters={ @@ -814,7 +814,7 @@ class ChatAgent(ShieldRunnerMixin): if tool_name_to_def.get(identifier, None): raise ValueError(f"Tool {identifier} already exists") if identifier: - tool_name_to_def[tool_def.identifier] = ToolDefinition( + tool_name_to_def[tool_def.identifier] = ToolDefinitionDeprecated( tool_name=identifier, description=tool_def.description, parameters={ @@ -854,30 +854,23 @@ class ChatAgent(ShieldRunnerMixin): tool_call: ToolCall, ) -> ToolInvocationResult: tool_name = tool_call.tool_name - 
registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs] + registered_tool_names = list(self.tool_name_to_args.keys()) if tool_name not in registered_tool_names: raise ValueError( f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}" ) - if isinstance(tool_name, BuiltinTool): - if tool_name == BuiltinTool.brave_search: - tool_name_str = WEB_SEARCH_TOOL - else: - tool_name_str = tool_name.value - else: - tool_name_str = tool_name - logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}") + logger.info(f"executing tool call: {tool_name} with args: {tool_call.arguments}") result = await self.tool_runtime_api.invoke_tool( - tool_name=tool_name_str, + tool_name=tool_name, kwargs={ "session_id": session_id, # get the arguments generated by the model and augment with toolgroup arg overrides for the agent **tool_call.arguments, - **self.tool_name_to_args.get(tool_name_str, {}), + **self.tool_name_to_args.get(tool_name, {}), }, ) - logger.debug(f"tool call {tool_name_str} completed with result: {result}") + logger.debug(f"tool call {tool_name} completed with result: {result}") return result async def handle_documents( diff --git a/llama_stack/providers/inline/inference/vllm/openai_utils.py b/llama_stack/providers/inline/inference/vllm/openai_utils.py index 90b5398f9..73a4fcbbe 100644 --- a/llama_stack/providers/inline/inference/vllm/openai_utils.py +++ b/llama_stack/providers/inline/inference/vllm/openai_utils.py @@ -16,7 +16,7 @@ from llama_stack.apis.inference import ( ToolChoice, UserMessage, ) -from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition +from llama_stack.models.llama.datatypes import ToolDefinition, ToolType from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, get_sampling_options, @@ -65,7 +65,7 @@ def _llama_stack_tools_to_openai_tools( result = [] for t in tools: - if isinstance(t.tool_name, BuiltinTool): + if t.type != ToolType.function.value: raise NotImplementedError("Built-in tools not yet implemented") if t.parameters is None: parameters = None diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index eda1a179c..acdc79ca6 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -45,7 +45,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model, ModelType -from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall +from llama_stack.models.llama.datatypes import StopReason, ToolCall from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( @@ -110,6 +110,8 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict] for tool in tools: properties = {} compat_required = [] + + tool_name = tool.name if tool.parameters: for tool_key, tool_param in tool.parameters.items(): properties[tool_key] = {"type": tool_param.param_type} @@ -120,12 +122,6 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict] if tool_param.required: compat_required.append(tool_key) - # The tool.tool_name can be a str or a BuiltinTool enum. If - # it's the latter, convert to a string. 
- tool_name = tool.tool_name - if isinstance(tool_name, BuiltinTool): - tool_name = tool_name.value - compat_tool = { "type": "function", "function": { diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index 78b47eb56..56f8f1146 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -17,7 +17,6 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.providers.datatypes import ToolsProtocolPrivate from .config import BraveSearchToolConfig @@ -61,7 +60,6 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest parameter_type="string", ) ], - built_in_type=BuiltinTool.brave_search, ) ] diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index b264c7312..46bed694b 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -80,12 +80,12 @@ from llama_stack.apis.inference import ( UserMessage, ) from llama_stack.models.llama.datatypes import ( - BuiltinTool, GreedySamplingStrategy, SamplingParams, StopReason, ToolCall, ToolDefinition, + ToolType, TopKSamplingStrategy, TopPSamplingStrategy, ) @@ -271,7 +271,7 @@ def process_chat_completion_response( else: # only return tool_calls if provided in the request new_tool_calls = [] - request_tools = {t.tool_name: t for t in request.tools} + request_tools = {t.name: t for t in request.tools} for t in raw_message.tool_calls: if t.tool_name in request_tools: new_tool_calls.append(t) @@ -423,7 +423,7 @@ async def process_chat_completion_stream_response( ) ) - request_tools = {t.tool_name: t for t in request.tools} + request_tools = {t.name: t for t in request.tools} for tool_call in message.tool_calls: if tool_call.tool_name in request_tools: yield ChatCompletionResponseStreamChunk( @@ -574,7 +574,7 @@ async def convert_message_to_openai_dict_new( OpenAIChatCompletionMessageToolCall( id=tool.call_id, function=OpenAIFunction( - name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), + name=tool.tool_name, arguments=json.dumps(tool.arguments), ), type="function", @@ -638,7 +638,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict: Convert a ToolDefinition to an OpenAI API-compatible dictionary. ToolDefinition: - tool_name: str | BuiltinTool + tool_name: str description: Optional[str] parameters: Optional[Dict[str, ToolParamDefinition]] @@ -677,10 +677,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict: } function = out["function"] - if isinstance(tool.tool_name, BuiltinTool): - function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient? 
- else: - function.update(name=tool.tool_name) + function.update(name=tool.name) if tool.description: function.update(description=tool.description) @@ -761,6 +758,7 @@ def _convert_openai_tool_calls( return [ ToolCall( + type=ToolType.function, call_id=call.id, tool_name=call.function.name, arguments=json.loads(call.function.arguments), @@ -975,6 +973,7 @@ async def convert_openai_chat_completion_stream( try: arguments = json.loads(buffer["arguments"]) tool_call = ToolCall( + type=ToolType.function, call_id=buffer["call_id"], tool_name=buffer["name"], arguments=arguments, diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 1edf445c0..938c49eef 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -43,6 +43,7 @@ from llama_stack.models.llama.datatypes import ( Role, StopReason, ToolPromptFormat, + ToolType, is_multimodal, ) from llama_stack.models.llama.llama3.chat_format import ChatFormat @@ -374,8 +375,8 @@ def augment_messages_for_tools_llama_3_1( messages.append(SystemMessage(content=sys_content)) - has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools) - if has_custom_tools: + custom_tools = [t for t in request.tools if t.type == ToolType.function.value] + if custom_tools: fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json if fmt == ToolPromptFormat.json: tool_gen = JsonCustomToolGenerator() @@ -384,7 +385,6 @@ def augment_messages_for_tools_llama_3_1( else: raise ValueError(f"Non supported ToolPromptFormat {fmt}") - custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)] custom_template = tool_gen.gen(custom_tools) messages.append(UserMessage(content=custom_template.render())) @@ -407,7 +407,7 @@ def augment_messages_for_tools_llama_3_2( sys_content = "" custom_tools, builtin_tools = [], [] for t in request.tools: - if isinstance(t.tool_name, str): + if t.type == ToolType.function.value: custom_tools.append(t) else: builtin_tools.append(t) @@ -419,7 +419,7 @@ def augment_messages_for_tools_llama_3_2( sys_content += tool_template.render() sys_content += "\n" - custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)] + custom_tools = [t for t in request.tools if t.type == ToolType.function.value] if custom_tools: fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list if fmt != ToolPromptFormat.python_list: diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py index 7011dc02d..3a2c57fd4 100644 --- a/tests/integration/agents/test_agents.py +++ b/tests/integration/agents/test_agents.py @@ -188,16 +188,12 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent } ], session_id=session_id, + stream=False, ) - logs = [str(log) for log in AgentEventLogger().log(response) if log is not None] - logs_str = "".join(logs) - - assert "tool_execution>" in logs_str - assert "Tool:brave_search Response:" in logs_str - assert "mark zuckerberg" in logs_str.lower() - if len(agent_config["output_shields"]) > 0: - assert "No Violation" in logs_str + tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution") + assert tool_execution_step.tool_calls[0].tool_name == "web_search" + assert "mark zuckerberg" in response.output_message.content.lower() def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config): 
diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py index c9649df60..62bbd536d 100644 --- a/tests/integration/inference/test_text_inference.py +++ b/tests/integration/inference/test_text_inference.py @@ -284,6 +284,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_ "test_case", [ "inference:chat_completion:tool_calling", + "inference:chat_completion:tool_calling_deprecated", ], ) def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case): @@ -300,7 +301,9 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_mo assert response.completion_message.role == "assistant" assert len(response.completion_message.tool_calls) == 1 - assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"] + assert response.completion_message.tool_calls[0].tool_name == ( + tc["tools"][0]["tool_name"] if "tool_name" in tc["tools"][0] else tc["tools"][0]["name"] + ) assert response.completion_message.tool_calls[0].arguments == tc["expected"] @@ -334,7 +337,7 @@ def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models stream=True, ) tool_invocation_content = extract_tool_invocation_content(response) - expected_tool_name = tc["tools"][0]["tool_name"] + expected_tool_name = tc["tools"][0]["tool_name"] if "tool_name" in tc["tools"][0] else tc["tools"][0]["name"] expected_argument = tc["expected"] assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]" @@ -358,7 +361,7 @@ def test_text_chat_completion_with_tool_choice_required(client_with_models, text stream=True, ) tool_invocation_content = extract_tool_invocation_content(response) - expected_tool_name = tc["tools"][0]["tool_name"] + expected_tool_name = tc["tools"][0]["tool_name"] if "tool_name" in tc["tools"][0] else tc["tools"][0]["name"] expected_argument = tc["expected"] assert tool_invocation_content == f"[{expected_tool_name}, {expected_argument}]" @@ -432,14 +435,11 @@ def test_text_chat_completion_tool_calling_tools_not_in_request( ): tc = TestCase(test_case) - # TODO: more dynamic lookup on tool_prompt_format for model family - tool_prompt_format = "json" if "3.1" in text_model_id else "python_list" request = { "model_id": text_model_id, "messages": tc["messages"], "tools": tc["tools"], "tool_choice": "auto", - "tool_prompt_format": tool_prompt_format, "stream": streaming, } @@ -457,3 +457,30 @@ def test_text_chat_completion_tool_calling_tools_not_in_request( else: for tc in response.completion_message.tool_calls: assert tc.tool_name == "get_object_namespace_list" + + +@pytest.mark.parametrize( + "test_case", + [ + "inference:chat_completion:tool_calling_builtin_web_search", + "inference:chat_completion:tool_calling_builtin_brave_search", + "inference:chat_completion:tool_calling_builtin_code_interpreter", + "inference:chat_completion:tool_calling_builtin_code_interpreter_deprecated", + ], +) +def test_text_chat_completion_tool_calling_builtin(client_with_models, text_model_id, test_case): + tc = TestCase(test_case) + + request = { + "model_id": text_model_id, + "messages": tc["messages"], + "tools": tc["tools"], + "tool_choice": "auto", + "stream": False, + } + + response = client_with_models.inference.chat_completion(**request) + + for tool_call in response.completion_message.tool_calls: + print(tool_call) + assert tool_call.tool_name == tc["expected"] diff --git 
a/tests/integration/test_cases/inference/chat_completion.json b/tests/integration/test_cases/inference/chat_completion.json index e87c046b0..3b47b0816 100644 --- a/tests/integration/test_cases/inference/chat_completion.json +++ b/tests/integration/test_cases/inference/chat_completion.json @@ -49,7 +49,7 @@ "expected": "Washington" } }, - "tool_calling": { + "tool_calling_deprecated": { "data": { "messages": [ {"role": "system", "content": "Pretend you are a weather assistant."}, @@ -72,6 +72,30 @@ } } }, + "tool_calling": { + "data": { + "messages": [ + {"role": "system", "content": "Pretend you are a weather assistant."}, + {"role": "user", "content": "What's the weather like in San Francisco?"} + ], + "tools": [ + { + "type": "function", + "name": "get_weather", + "description": "Get the current weather", + "parameters": { + "location": { + "param_type": "string", + "description": "The city and state (both required), e.g. San Francisco, CA." + } + } + } + ], + "expected": { + "location": "San Francisco, CA" + } + } + }, "sample_messages_tool_calling": { "data": { "messages": [ @@ -128,6 +152,56 @@ } } }, + "tool_calling_builtin_web_search": { + "data": { + "messages": [ + {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."}, + {"role": "user", "content": "What's the weather like in San Francisco?"} + ], + "tools": [ + { + "type": "web_search" + } + ], + "expected": "web_search" + } + }, + "tool_calling_builtin_brave_search": { + "data": { + "messages": [ + {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."}, + {"role": "user", "content": "What's the weather like in San Francisco?"} + ], + "tools": [{ + "tool_name": "brave_search" + }], + "expected": "web_search" + } + }, + "tool_calling_builtin_code_interpreter_deprecated": { + "data": { + "messages": [ + {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."}, + {"role": "user", "content": "plot log(x) from -10 to 10"} + ], + "tools": [{ + "tool_name": "code_interpreter" + }], + "expected": "code_interpreter" + } + }, + "tool_calling_builtin_code_interpreter": { + "data": { + "messages": [ + {"role": "system", "content": "You are a helpful assistant. Use available tools to answer the question."}, + {"role": "user", "content": "plot log(x) from -10 to 10"} + ], + "tools": [{ + "type": "code_interpreter" + }], + "expected": "code_interpreter" + } + }, "tool_calling_tools_absent": { "data": { "messages": [ @@ -146,6 +220,7 @@ "tool_calls": [ { "call_id": "1", + "type": "function", "tool_name": "get_object_namespace_list", "arguments": { "kind": "pod", diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py index 0e2780e50..b2c1ae657 100644 --- a/tests/unit/models/test_prompt_adapter.py +++ b/tests/unit/models/test_prompt_adapter.py @@ -5,7 +5,8 @@ # the root directory of this source tree. 
import asyncio -import unittest + +import pytest from llama_stack.apis.inference import ( ChatCompletionRequest, @@ -17,10 +18,11 @@ from llama_stack.apis.inference import ( UserMessage, ) from llama_stack.models.llama.datatypes import ( - BuiltinTool, - ToolDefinition, + CodeInterpreterTool, + FunctionTool, ToolParamDefinition, ToolPromptFormat, + WebSearchTool, ) from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, @@ -31,258 +33,269 @@ MODEL = "Llama3.1-8B-Instruct" MODEL3_2 = "Llama3.2-3B-Instruct" -class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): - asyncio.get_running_loop().set_debug(False) +@pytest.fixture(autouse=True) +def setup_loop(): + loop = asyncio.get_event_loop() + loop.set_debug(False) + return loop - async def test_system_default(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 2) - self.assertEqual(messages[-1].content, content) - self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) - async def test_system_builtin_only(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 2) - self.assertEqual(messages[-1].content, content) - self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) - self.assertTrue("Tools: brave_search" in messages[0].content) +@pytest.mark.asyncio +async def test_system_default(): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 2 + assert messages[-1].content == content + assert "Cutting Knowledge Date: December 2023" in messages[0].content - async def test_system_custom_only(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ) - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json), - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 3) - self.assertTrue("Environment: ipython" in messages[0].content) - self.assertTrue("Return function calls in JSON format" in messages[1].content) - self.assertEqual(messages[-1].content, content) +@pytest.mark.asyncio +async def test_system_builtin_only(): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + tools=[ + CodeInterpreterTool(), + WebSearchTool(), + ], + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 2 + assert messages[-1].content == content + assert "Cutting Knowledge Date: December 2023" in messages[0].content + assert "Tools: brave_search" in messages[0].content - async def test_system_custom_and_builtin(self): - content = "Hello !" 
-        request = ChatCompletionRequest(
-            model=MODEL,
-            messages=[
-                UserMessage(content=content),
-            ],
-            tools=[
-                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-                ToolDefinition(tool_name=BuiltinTool.brave_search),
-                ToolDefinition(
-                    tool_name="custom1",
-                    description="custom1 tool",
-                    parameters={
-                        "param1": ToolParamDefinition(
-                            param_type="str",
-                            description="param1 description",
-                            required=True,
-                        ),
-                    },
-                ),
-            ],
-        )
-        messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 3)
-        self.assertTrue("Environment: ipython" in messages[0].content)
-        self.assertTrue("Tools: brave_search" in messages[0].content)
+@pytest.mark.asyncio
+async def test_system_custom_only():
+    content = "Hello !"
+    request = ChatCompletionRequest(
+        model=MODEL,
+        messages=[
+            UserMessage(content=content),
+        ],
+        tools=[
+            FunctionTool(
+                name="custom1",
+                description="custom1 tool",
+                parameters={
+                    "param1": ToolParamDefinition(
+                        param_type="str",
+                        description="param1 description",
+                        required=True,
+                    ),
+                },
+            )
+        ],
+        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
+    )
+    messages = chat_completion_request_to_messages(request, MODEL)
+    assert len(messages) == 3
+    assert "Environment: ipython" in messages[0].content
+    assert "Return function calls in JSON format" in messages[1].content
+    assert messages[-1].content == content
-        self.assertTrue("Return function calls in JSON format" in messages[1].content)
-        self.assertEqual(messages[-1].content, content)
-
-    async def test_completion_message_encoding(self):
-        request = ChatCompletionRequest(
-            model=MODEL3_2,
-            messages=[
-                UserMessage(content="hello"),
-                CompletionMessage(
-                    content="",
-                    stop_reason=StopReason.end_of_turn,
-                    tool_calls=[
-                        ToolCall(
-                            tool_name="custom1",
-                            arguments={"param1": "value1"},
-                            call_id="123",
-                        )
-                    ],
-                ),
-            ],
-            tools=[
-                ToolDefinition(
-                    tool_name="custom1",
-                    description="custom1 tool",
-                    parameters={
-                        "param1": ToolParamDefinition(
-                            param_type="str",
-                            description="param1 description",
-                            required=True,
-                        ),
-                    },
-                ),
-            ],
-            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
-        )
-        prompt = await chat_completion_request_to_prompt(request, request.model)
-        self.assertIn('[custom1(param1="value1")]', prompt)
-
-        request.model = MODEL
-        request.tool_config.tool_prompt_format = ToolPromptFormat.json
-        prompt = await chat_completion_request_to_prompt(request, request.model)
-        self.assertIn(
-            '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
-            prompt,
-        )
-
-    async def test_user_provided_system_message(self):
-        content = "Hello !"
-        system_prompt = "You are a pirate"
-        request = ChatCompletionRequest(
-            model=MODEL,
-            messages=[
-                SystemMessage(content=system_prompt),
-                UserMessage(content=content),
-            ],
-            tools=[
-                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-            ],
-        )
-        messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 2, messages)
-        self.assertTrue(messages[0].content.endswith(system_prompt))
-
-        self.assertEqual(messages[-1].content, content)
-
-    async def test_repalce_system_message_behavior_builtin_tools(self):
-        content = "Hello !"
-        system_prompt = "You are a pirate"
-        request = ChatCompletionRequest(
-            model=MODEL,
-            messages=[
-                SystemMessage(content=system_prompt),
-                UserMessage(content=content),
-            ],
-            tools=[
-                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-            ],
-            tool_config=ToolConfig(
-                tool_choice="auto",
-                tool_prompt_format="python_list",
-                system_message_behavior="replace",
+@pytest.mark.asyncio
+async def test_system_custom_and_builtin():
+    content = "Hello !"
+    request = ChatCompletionRequest(
+        model=MODEL,
+        messages=[
+            UserMessage(content=content),
+        ],
+        tools=[
+            CodeInterpreterTool(),
+            WebSearchTool(),
+            FunctionTool(
+                name="custom1",
+                description="custom1 tool",
+                parameters={
+                    "param1": ToolParamDefinition(
+                        param_type="str",
+                        description="param1 description",
+                        required=True,
+                    ),
+                },
             ),
-        )
-        messages = chat_completion_request_to_messages(request, MODEL3_2)
-        self.assertEqual(len(messages), 2, messages)
-        self.assertTrue(messages[0].content.endswith(system_prompt))
-        self.assertIn("Environment: ipython", messages[0].content)
-        self.assertEqual(messages[-1].content, content)
+        ],
+    )
+    messages = chat_completion_request_to_messages(request, MODEL)
+    assert len(messages) == 3
+    assert "Environment: ipython" in messages[0].content
+    assert "Tools: brave_search" in messages[0].content
+    assert "Return function calls in JSON format" in messages[1].content
+    assert messages[-1].content == content
 
-    async def test_repalce_system_message_behavior_custom_tools(self):
-        content = "Hello !"
-        system_prompt = "You are a pirate"
-        request = ChatCompletionRequest(
-            model=MODEL,
-            messages=[
-                SystemMessage(content=system_prompt),
-                UserMessage(content=content),
-            ],
-            tools=[
-                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-                ToolDefinition(
-                    tool_name="custom1",
-                    description="custom1 tool",
-                    parameters={
-                        "param1": ToolParamDefinition(
-                            param_type="str",
-                            description="param1 description",
-                            required=True,
-                        ),
-                    },
-                ),
-            ],
-            tool_config=ToolConfig(
-                tool_choice="auto",
-                tool_prompt_format="python_list",
-                system_message_behavior="replace",
+
+@pytest.mark.asyncio
+async def test_completion_message_encoding():
+    request = ChatCompletionRequest(
+        model=MODEL3_2,
+        messages=[
+            UserMessage(content="hello"),
+            CompletionMessage(
+                content="",
+                stop_reason=StopReason.end_of_turn,
+                tool_calls=[
+                    ToolCall(
+                        type="function",
+                        tool_name="custom1",
+                        arguments={"param1": "value1"},
+                        call_id="123",
+                    )
+                ],
             ),
-        )
-        messages = chat_completion_request_to_messages(request, MODEL3_2)
-
-        self.assertEqual(len(messages), 2, messages)
-        self.assertTrue(messages[0].content.endswith(system_prompt))
-        self.assertIn("Environment: ipython", messages[0].content)
-        self.assertEqual(messages[-1].content, content)
-
-    async def test_replace_system_message_behavior_custom_tools_with_template(self):
-        content = "Hello !"
-        system_prompt = "You are a pirate {{ function_description }}"
-        request = ChatCompletionRequest(
-            model=MODEL,
-            messages=[
-                SystemMessage(content=system_prompt),
-                UserMessage(content=content),
-            ],
-            tools=[
-                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-                ToolDefinition(
-                    tool_name="custom1",
-                    description="custom1 tool",
-                    parameters={
-                        "param1": ToolParamDefinition(
-                            param_type="str",
-                            description="param1 description",
-                            required=True,
-                        ),
-                    },
-                ),
-            ],
-            tool_config=ToolConfig(
-                tool_choice="auto",
-                tool_prompt_format="python_list",
-                system_message_behavior="replace",
+        ],
+        tools=[
+            FunctionTool(
+                name="custom1",
+                description="custom1 tool",
+                parameters={
+                    "param1": ToolParamDefinition(
+                        param_type="str",
+                        description="param1 description",
+                        required=True,
+                    ),
+                },
             ),
-        )
-        messages = chat_completion_request_to_messages(request, MODEL3_2)
+        ],
+        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
+    )
+    prompt = await chat_completion_request_to_prompt(request, request.model)
+    assert '[custom1(param1="value1")]' in prompt
 
-        self.assertEqual(len(messages), 2, messages)
-        self.assertIn("Environment: ipython", messages[0].content)
-        self.assertIn("You are a pirate", messages[0].content)
-        # function description is present in the system prompt
-        self.assertIn('"name": "custom1"', messages[0].content)
-        self.assertEqual(messages[-1].content, content)
+    request.model = MODEL
+    request.tool_config.tool_prompt_format = ToolPromptFormat.json
+    prompt = await chat_completion_request_to_prompt(request, request.model)
+    assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
+
+
+@pytest.mark.asyncio
+async def test_user_provided_system_message():
+    content = "Hello !"
+    system_prompt = "You are a pirate"
+    request = ChatCompletionRequest(
+        model=MODEL,
+        messages=[
+            SystemMessage(content=system_prompt),
+            UserMessage(content=content),
+        ],
+        tools=[
+            CodeInterpreterTool(),
+        ],
+    )
+    messages = chat_completion_request_to_messages(request, MODEL)
+    assert len(messages) == 2
+    assert messages[0].content.endswith(system_prompt)
+    assert messages[-1].content == content
+
+
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_builtin_tools():
+    content = "Hello !"
+    system_prompt = "You are a pirate"
+    request = ChatCompletionRequest(
+        model=MODEL,
+        messages=[
+            SystemMessage(content=system_prompt),
+            UserMessage(content=content),
+        ],
+        tools=[
+            CodeInterpreterTool(),
+        ],
+        tool_config=ToolConfig(
+            tool_choice="auto",
+            tool_prompt_format="python_list",
+            system_message_behavior="replace",
+        ),
+    )
+    messages = chat_completion_request_to_messages(request, MODEL3_2)
+    assert len(messages) == 2
+    assert messages[0].content.endswith(system_prompt)
+    assert "Environment: ipython" in messages[0].content
+    assert messages[-1].content == content
+
+
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_custom_tools():
+    content = "Hello !"
+    system_prompt = "You are a pirate"
+    request = ChatCompletionRequest(
+        model=MODEL,
+        messages=[
+            SystemMessage(content=system_prompt),
+            UserMessage(content=content),
+        ],
+        tools=[
+            CodeInterpreterTool(),
+            FunctionTool(
+                name="custom1",
+                description="custom1 tool",
+                parameters={
+                    "param1": ToolParamDefinition(
+                        param_type="str",
+                        description="param1 description",
+                        required=True,
+                    ),
+                },
+            ),
+        ],
+        tool_config=ToolConfig(
+            tool_choice="auto",
+            tool_prompt_format="python_list",
+            system_message_behavior="replace",
+        ),
+    )
+    messages = chat_completion_request_to_messages(request, MODEL3_2)
+    assert len(messages) == 2
+    assert messages[0].content.endswith(system_prompt)
+    assert "Environment: ipython" in messages[0].content
+    assert messages[-1].content == content
+
+
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_custom_tools_with_template():
+    content = "Hello !"
+    system_prompt = "You are a pirate {{ function_description }}"
+    request = ChatCompletionRequest(
+        model=MODEL,
+        messages=[
+            SystemMessage(content=system_prompt),
+            UserMessage(content=content),
+        ],
+        tools=[
+            CodeInterpreterTool(),
+            FunctionTool(
+                name="custom1",
+                description="custom1 tool",
+                parameters={
+                    "param1": ToolParamDefinition(
+                        param_type="str",
+                        description="param1 description",
+                        required=True,
+                    ),
+                },
+            ),
+        ],
+        tool_config=ToolConfig(
+            tool_choice="auto",
+            tool_prompt_format="python_list",
+            system_message_behavior="replace",
+        ),
+    )
+    messages = chat_completion_request_to_messages(request, MODEL3_2)
+    assert len(messages) == 2
+    assert "Environment: ipython" in messages[0].content
+    assert "You are a pirate" in messages[0].content
+    assert '"name": "custom1"' in messages[0].content
+    assert messages[-1].content == content
diff --git a/tests/unit/models/test_system_prompts.py b/tests/unit/models/test_system_prompts.py
index 1f4ccc7e3..e068ce775 100644
--- a/tests/unit/models/test_system_prompts.py
+++ b/tests/unit/models/test_system_prompts.py
@@ -33,7 +33,7 @@ class PromptTemplateTests(unittest.TestCase):
             if not example:
                 continue
             for tool in example:
-                assert tool.tool_name in text
+                assert tool.name in text
 
     def test_system_default(self):
         generator = SystemDefaultGenerator()
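The test changes above are the reference point for migrating callers to the new tool types. A minimal sketch of the new declaration style, using only the import path and constructors exercised in test_prompt_adapter.py (the tool name, description, and parameter values below are illustrative, not part of the diff):

    from llama_stack.models.llama.datatypes import (
        CodeInterpreterTool,
        FunctionTool,
        ToolParamDefinition,
        WebSearchTool,
    )

    # Previously a single ToolDefinition covered both builtins
    # (ToolDefinition(tool_name=BuiltinTool.brave_search)) and custom
    # functions. Each builtin is now its own class, discriminated by its
    # "type" field, and FunctionTool covers custom functions, with "name"
    # replacing the old "tool_name".
    tools = [
        WebSearchTool(),        # carries type "web_search"
        CodeInterpreterTool(),  # carries type "code_interpreter"
        FunctionTool(
            name="get_weather",
            description="Get the current weather",
            parameters={
                "location": ToolParamDefinition(
                    param_type="string",
                    description="The city and state (both required), e.g. San Francisco, CA.",
                    required=True,
                ),
            },
        ),
    ]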
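On the wire, the same change shows up in the chat_completion.json cases: builtin tools are requested by "type" rather than "tool_name", and tool calls emitted by the model now require a "type" field. A sketch of both payload shapes, with field names and values taken from the test data above (the deprecated form remains accepted via ToolDefinitionDeprecated):

    # Deprecated request fragment (tool_name keyed):
    deprecated_tools = [{"tool_name": "brave_search"}]

    # New request fragment, a discriminated union keyed on "type":
    new_tools = [{"type": "web_search"}]

    # Recorded tool calls now carry "type" alongside call_id, tool_name,
    # and arguments:
    tool_call = {
        "type": "function",
        "call_id": "1",
        "tool_name": "get_object_namespace_list",
        "arguments": {"kind": "pod"},
    }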