diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 5a78f5bae..fb7525988 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -3714,7 +3714,7 @@ "client_tools": { "type": "array", "items": { - "$ref": "#/components/schemas/UserDefinedToolDef" + "$ref": "#/components/schemas/ToolDef" } }, "tool_choice": { @@ -3792,60 +3792,9 @@ } ] }, - "ToolParameter": { + "ToolDef": { "type": "object", "properties": { - "name": { - "type": "string" - }, - "parameter_type": { - "type": "string" - }, - "description": { - "type": "string" - }, - "required": { - "type": "boolean" - }, - "default": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "additionalProperties": false, - "required": [ - "name", - "parameter_type", - "description", - "required" - ] - }, - "UserDefinedToolDef": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "user_defined", - "default": "user_defined" - }, "name": { "type": "string" }, @@ -3890,11 +3839,53 @@ }, "additionalProperties": false, "required": [ - "type", + "name" + ] + }, + "ToolParameter": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "parameter_type": { + "type": "string" + }, + "description": { + "type": "string" + }, + "required": { + "type": "boolean" + }, + "default": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "additionalProperties": false, + "required": [ "name", + "parameter_type", "description", - "parameters", - "metadata" + "required" ] }, "CreateAgentRequest": { @@ -4589,49 +4580,6 @@ "session_id" ] }, - "BuiltInToolDef": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "built_in", - "default": "built_in" - }, - "built_in_type": { - "$ref": "#/components/schemas/BuiltinTool" - }, - "metadata": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "type", - "built_in_type" - ] - }, "MCPToolGroupDef": { "type": "object", "properties": { @@ -4651,16 +4599,6 @@ ], "title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information." }, - "ToolDef": { - "oneOf": [ - { - "$ref": "#/components/schemas/UserDefinedToolDef" - }, - { - "$ref": "#/components/schemas/BuiltInToolDef" - } - ] - }, "ToolGroupDef": { "oneOf": [ { @@ -7436,7 +7374,7 @@ "tool_group_id": { "type": "string" }, - "tool_group": { + "tool_group_def": { "$ref": "#/components/schemas/ToolGroupDef" }, "provider_id": { @@ -7446,7 +7384,7 @@ "additionalProperties": false, "required": [ "tool_group_id", - "tool_group" + "tool_group_def" ] }, "RunEvalRequest": { @@ -8098,10 +8036,6 @@ "name": "BenchmarkEvalTaskConfig", "description": "" }, - { - "name": "BuiltInToolDef", - "description": "" - }, { "name": "BuiltinTool", "description": "" @@ -8708,10 +8642,6 @@ "name": "UnstructuredLogEvent", "description": "" }, - { - "name": "UserDefinedToolDef", - "description": "" - }, { "name": "UserDefinedToolGroupDef", "description": "" @@ -8792,7 +8722,6 @@ "BatchCompletionRequest", "BatchCompletionResponse", "BenchmarkEvalTaskConfig", - "BuiltInToolDef", "BuiltinTool", "CancelTrainingJobRequest", "ChatCompletionRequest", @@ -8931,7 +8860,6 @@ "UnregisterModelRequest", "UnregisterToolGroupRequest", "UnstructuredLogEvent", - "UserDefinedToolDef", "UserDefinedToolGroupDef", "UserMessage", "VectorMemoryBank", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 72093b436..0937d8722 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -19,7 +19,7 @@ components: properties: client_tools: items: - $ref: '#/components/schemas/UserDefinedToolDef' + $ref: '#/components/schemas/ToolDef' type: array enable_session_persistence: type: boolean @@ -396,29 +396,6 @@ components: - type - eval_candidate type: object - BuiltInToolDef: - additionalProperties: false - properties: - built_in_type: - $ref: '#/components/schemas/BuiltinTool' - metadata: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - type: - const: built_in - default: built_in - type: string - required: - - type - - built_in_type - type: object BuiltinTool: enum: - brave_search @@ -1929,13 +1906,13 @@ components: properties: provider_id: type: string - tool_group: + tool_group_def: $ref: '#/components/schemas/ToolGroupDef' tool_group_id: type: string required: - tool_group_id - - tool_group + - tool_group_def type: object ResponseFormat: oneOf: @@ -2716,9 +2693,32 @@ components: - required type: string ToolDef: - oneOf: - - $ref: '#/components/schemas/UserDefinedToolDef' - - $ref: '#/components/schemas/BuiltInToolDef' + additionalProperties: false + properties: + description: + type: string + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + name: + type: string + parameters: + items: + $ref: '#/components/schemas/ToolParameter' + type: array + tool_prompt_format: + $ref: '#/components/schemas/ToolPromptFormat' + default: json + required: + - name + type: object ToolDefinition: additionalProperties: false properties: @@ -3087,41 +3087,6 @@ components: - message - severity type: object - UserDefinedToolDef: - additionalProperties: false - properties: - description: - type: string - metadata: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - name: - type: string - parameters: - items: - $ref: '#/components/schemas/ToolParameter' - type: array - tool_prompt_format: - $ref: '#/components/schemas/ToolPromptFormat' - default: json - type: - const: user_defined - default: user_defined - type: string - required: - - type - - name - - description - - parameters - - metadata - type: object UserDefinedToolGroupDef: additionalProperties: false properties: @@ -4823,8 +4788,6 @@ tags: - description: name: BenchmarkEvalTaskConfig -- description: - name: BuiltInToolDef - description: name: BuiltinTool - description: name: UnstructuredLogEvent -- description: - name: UserDefinedToolDef - description: name: UserDefinedToolGroupDef @@ -5316,7 +5276,6 @@ x-tagGroups: - BatchCompletionRequest - BatchCompletionResponse - BenchmarkEvalTaskConfig - - BuiltInToolDef - BuiltinTool - CancelTrainingJobRequest - ChatCompletionRequest @@ -5455,7 +5414,6 @@ x-tagGroups: - UnregisterModelRequest - UnregisterToolGroupRequest - UnstructuredLogEvent - - UserDefinedToolDef - UserDefinedToolGroupDef - UserMessage - VectorMemoryBank diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index acf8fa748..db0e3ab3b 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -36,7 +36,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.memory import MemoryBank from llama_stack.apis.safety import SafetyViolation -from llama_stack.apis.tools import UserDefinedToolDef +from llama_stack.apis.tools import ToolDef from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -157,7 +157,7 @@ class AgentConfigCommon(BaseModel): input_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list) tools: Optional[List[AgentTool]] = Field(default_factory=list) - client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list) + client_tools: Optional[List[ToolDef]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 6585f3fd2..bc19a8a02 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -48,30 +48,16 @@ class Tool(Resource): @json_schema_type -class UserDefinedToolDef(BaseModel): - type: Literal["user_defined"] = "user_defined" +class ToolDef(BaseModel): name: str - description: str - parameters: List[ToolParameter] - metadata: Dict[str, Any] + description: Optional[str] = None + parameters: Optional[List[ToolParameter]] = None + metadata: Optional[Dict[str, Any]] = None tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json ) -@json_schema_type -class BuiltInToolDef(BaseModel): - type: Literal["built_in"] = "built_in" - built_in_type: BuiltinTool - metadata: Optional[Dict[str, Any]] = None - - -ToolDef = register_schema( - Annotated[Union[UserDefinedToolDef, BuiltInToolDef], Field(discriminator="type")], - name="ToolDef", -) - - @json_schema_type class MCPToolGroupDef(BaseModel): """ @@ -100,7 +86,7 @@ ToolGroupDef = register_schema( @json_schema_type class ToolGroupInput(BaseModel): tool_group_id: str - tool_group: ToolGroupDef + tool_group_def: ToolGroupDef provider_id: Optional[str] = None @@ -127,7 +113,7 @@ class ToolGroups(Protocol): async def register_tool_group( self, tool_group_id: str, - tool_group: ToolGroupDef, + tool_group_def: ToolGroupDef, provider_id: Optional[str] = None, ) -> None: """Register a tool group""" diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index ccea470ae..b51de8fef 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -27,15 +27,12 @@ from llama_stack.apis.scoring_functions import ( ) from llama_stack.apis.shields import Shield, Shields from llama_stack.apis.tools import ( - BuiltInToolDef, MCPToolGroupDef, Tool, ToolGroup, ToolGroupDef, ToolGroups, ToolHost, - ToolPromptFormat, - UserDefinedToolDef, UserDefinedToolGroupDef, ) from llama_stack.distribution.datatypes import ( @@ -514,7 +511,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): async def register_tool_group( self, tool_group_id: str, - tool_group: ToolGroupDef, + tool_group_def: ToolGroupDef, provider_id: Optional[str] = None, ) -> None: tools = [] @@ -528,47 +525,31 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): provider_id = list(self.impls_by_provider_id.keys())[0] # parse tool group to the type if dict - tool_group = TypeAdapter(ToolGroupDef).validate_python(tool_group) - if isinstance(tool_group, MCPToolGroupDef): + tool_group_def = TypeAdapter(ToolGroupDef).validate_python(tool_group_def) + if isinstance(tool_group_def, MCPToolGroupDef): tool_defs = await self.impls_by_provider_id[provider_id].discover_tools( - tool_group + tool_group_def ) tool_host = ToolHost.model_context_protocol - elif isinstance(tool_group, UserDefinedToolGroupDef): - tool_defs = tool_group.tools + elif isinstance(tool_group_def, UserDefinedToolGroupDef): + tool_defs = tool_group_def.tools else: - raise ValueError(f"Unknown tool group: {tool_group}") + raise ValueError(f"Unknown tool group: {tool_group_def}") for tool_def in tool_defs: - if isinstance(tool_def, UserDefinedToolDef): - tools.append( - Tool( - identifier=tool_def.name, - tool_group=tool_group_id, - description=tool_def.description, - parameters=tool_def.parameters, - provider_id=provider_id, - tool_prompt_format=tool_def.tool_prompt_format, - provider_resource_id=tool_def.name, - metadata=tool_def.metadata, - tool_host=tool_host, - ) - ) - elif isinstance(tool_def, BuiltInToolDef): - tools.append( - Tool( - identifier=tool_def.built_in_type.value, - tool_group=tool_group_id, - built_in_type=tool_def.built_in_type, - description="", - parameters=[], - provider_id=provider_id, - tool_prompt_format=ToolPromptFormat.json, - provider_resource_id=tool_def.built_in_type.value, - metadata=tool_def.metadata, - tool_host=tool_host, - ) + tools.append( + Tool( + identifier=tool_def.name, + tool_group=tool_group_id, + description=tool_def.description or "", + parameters=tool_def.parameters or [], + provider_id=provider_id, + tool_prompt_format=tool_def.tool_prompt_format, + provider_resource_id=tool_def.name, + metadata=tool_def.metadata, + tool_host=tool_host, ) + ) for tool in tools: existing_tool = await self.get_tool(tool.identifier) # Compare existing and new object if one exists diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index e4ebb3011..cea4146e9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -387,7 +387,7 @@ class ChatAgent(ShieldRunnerMixin): ) ) extra_args = tool_args.get("memory", {}) - args = { + tool_args = { # Query memory with the last message's content "query": input_messages[-1], **extra_args, @@ -396,8 +396,8 @@ class ChatAgent(ShieldRunnerMixin): session_info = await self.storage.get_session_info(session_id) # if the session has a memory bank id, let the memory tool use it if session_info.memory_bank_id: - args["memory_bank_id"] = session_info.memory_bank_id - serialized_args = tracing.serialize_value(args) + tool_args["memory_bank_id"] = session_info.memory_bank_id + serialized_args = tracing.serialize_value(tool_args) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -416,7 +416,7 @@ class ChatAgent(ShieldRunnerMixin): ) result = await self.tool_runtime_api.invoke_tool( tool_name="memory", - args=args, + args=tool_args, ) yield AgentTurnResponseStreamChunk( @@ -482,11 +482,7 @@ class ChatAgent(ShieldRunnerMixin): async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, - tools=[ - tool - for tool in tool_defs.values() - if tool.tool_name != "memory" - ], + tools=[tool for tool in tool_defs.values()], tool_prompt_format=self.agent_config.tool_prompt_format, stream=True, sampling_params=sampling_params, @@ -728,10 +724,17 @@ class ChatAgent(ShieldRunnerMixin): continue tool_def = await self.tool_groups_api.get_tool(tool_name) + if tool_def is None: + raise ValueError(f"Tool {tool_name} not found") - if tool_def.built_in_type: - ret[tool_def.built_in_type] = ToolDefinition( - tool_name=tool_def.built_in_type + if tool_def.identifier.startswith("builtin::"): + built_in_type = tool_def.identifier[len("builtin::") :] + if built_in_type == "web_search": + built_in_type = "brave_search" + if built_in_type not in BuiltinTool.__members__: + raise ValueError(f"Unknown built-in tool: {built_in_type}") + ret[built_in_type] = ToolDefinition( + tool_name=BuiltinTool(built_in_type) ) continue @@ -759,52 +762,52 @@ class ChatAgent(ShieldRunnerMixin): tool_defs: Dict[str, ToolDefinition], ) -> None: memory_tool = tool_defs.get("memory", None) - code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None) - if documents: - content_items = [] - url_items = [] - pattern = re.compile("^(https?://|file://|data:)") - for d in documents: - if isinstance(d.content, URL): - url_items.append(d.content) - elif pattern.match(d.content): - url_items.append(URL(uri=d.content)) - else: - content_items.append(d) - - # Save the contents to a tempdir and use its path as a URL if code interpreter is present - if code_interpreter_tool: - for c in content_items: - temp_file_path = os.path.join( - self.tempdir, f"{make_random_string()}.txt" - ) - with open(temp_file_path, "w") as temp_file: - temp_file.write(c.content) - url_items.append(URL(uri=f"file://{temp_file_path}")) - - if memory_tool and code_interpreter_tool: - # if both memory and code_interpreter are available, we download the URLs - # and attach the data to the last message. - msg = await attachment_message(self.tempdir, url_items) - input_messages.append(msg) - # Since memory is present, add all the data to the memory bank - await self.add_to_session_memory_bank(session_id, documents) - elif code_interpreter_tool: - # if only code_interpreter is available, we download the URLs to a tempdir - # and attach the path to them as a message to inference with the - # assumption that the model invokes the code_interpreter tool with the path - msg = await attachment_message(self.tempdir, url_items) - input_messages.append(msg) - elif memory_tool: - # if only memory is available, we load the data from the URLs and content items to the memory bank - await self.add_to_session_memory_bank(session_id, documents) + code_interpreter_tool = tool_defs.get("code_interpreter", None) + content_items = [] + url_items = [] + pattern = re.compile("^(https?://|file://|data:)") + for d in documents: + if isinstance(d.content, URL): + url_items.append(d.content) + elif pattern.match(d.content): + url_items.append(URL(uri=d.content)) else: - # if no memory or code_interpreter tool is available, - # we try to load the data from the URLs and content items as a message to inference - # and add it to the last message's context - input_messages[-1].context = content_items + await load_data_from_urls( - url_items + content_items.append(d) + + # Save the contents to a tempdir and use its path as a URL if code interpreter is present + if code_interpreter_tool: + for c in content_items: + temp_file_path = os.path.join( + self.tempdir, f"{make_random_string()}.txt" ) + with open(temp_file_path, "w") as temp_file: + temp_file.write(c.content) + url_items.append(URL(uri=f"file://{temp_file_path}")) + + if memory_tool and code_interpreter_tool: + # if both memory and code_interpreter are available, we download the URLs + # and attach the data to the last message. + msg = await attachment_message(self.tempdir, url_items) + input_messages.append(msg) + # Since memory is present, add all the data to the memory bank + await self.add_to_session_memory_bank(session_id, documents) + elif code_interpreter_tool: + # if only code_interpreter is available, we download the URLs to a tempdir + # and attach the path to them as a message to inference with the + # assumption that the model invokes the code_interpreter tool with the path + msg = await attachment_message(self.tempdir, url_items) + input_messages.append(msg) + elif memory_tool: + # if only memory is available, we load the data from the URLs and content items to the memory bank + await self.add_to_session_memory_bank(session_id, documents) + else: + # if no memory or code_interpreter tool is available, + # we try to load the data from the URLs and content items as a message to inference + # and add it to the last message's context + input_messages[-1].context = "\n".join( + [doc.content for doc in content_items] + + await load_data_from_urls(url_items) + ) async def _ensure_memory_bank(self, session_id: str) -> str: session_info = await self.storage.get_session_info(session_id) @@ -909,7 +912,10 @@ async def execute_tool_call_maybe( tool_call = message.tool_calls[0] name = tool_call.tool_name if isinstance(name, BuiltinTool): - name = name.value + if name == BuiltinTool.brave_search: + name = "builtin::web_search" + else: + name = "builtin::" + name.value result = await tool_runtime_api.invoke_tool( tool_name=name, args=dict( diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py index 2e062d6d7..0fe0d0243 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -30,8 +30,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): pass async def register_tool(self, tool: Tool): - if tool.identifier != "code_interpreter": - raise ValueError(f"Tool identifier {tool.identifier} is not supported") + pass async def unregister_tool(self, tool_id: str) -> None: return diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index ca39ca03f..a760bb08a 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -17,7 +17,7 @@ from llama_stack_client.types.agent_create_params import AgentConfig from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument from llama_stack_client.types.memory_insert_params import Document from llama_stack_client.types.shared.completion_message import CompletionMessage -from llama_stack_client.types.tool_def_param import UserDefinedToolDefParameter +from llama_stack_client.types.tool_def_param import Parameter class TestClientTool(ClientTool): @@ -53,15 +53,15 @@ class TestClientTool(ClientTool): def get_description(self) -> str: return "Get the boiling point of imaginary liquids (eg. polyjuice)" - def get_params_definition(self) -> Dict[str, UserDefinedToolDefParameter]: + def get_params_definition(self) -> Dict[str, Parameter]: return { - "liquid_name": UserDefinedToolDefParameter( + "liquid_name": Parameter( name="liquid_name", parameter_type="string", description="The name of the liquid", required=True, ), - "celcius": UserDefinedToolDefParameter( + "celcius": Parameter( name="celcius", parameter_type="boolean", description="Whether to return the boiling point in Celcius", @@ -149,11 +149,11 @@ def test_agent_simple(llama_stack_client, agent_config): assert "I can't" in logs_str -def test_builtin_tool_brave_search(llama_stack_client, agent_config): +def test_builtin_tool_web_search(llama_stack_client, agent_config): agent_config = { **agent_config, "tools": [ - "brave_search", + "builtin::web_search", ], } agent = Agent(llama_stack_client, agent_config) @@ -182,7 +182,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config): agent_config = { **agent_config, "tools": [ - "code_interpreter", + "builtin::code_interpreter", ], } agent = Agent(llama_stack_client, agent_config) @@ -209,9 +209,9 @@ def test_code_execution(llama_stack_client): model="meta-llama/Llama-3.1-70B-Instruct", instructions="You are a helpful assistant", tools=[ - "code_interpreter", + "builtin::code_interpreter", ], - tool_choice="auto", + tool_choice="required", input_shields=[], output_shields=[], enable_session_persistence=False, @@ -242,7 +242,7 @@ def test_code_execution(llama_stack_client): ) logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) - print(logs_str) + assert "Tool:code_interpreter" in logs_str def test_custom_tool(llama_stack_client, agent_config): @@ -250,7 +250,7 @@ def test_custom_tool(llama_stack_client, agent_config): agent_config = { **agent_config, "model": "meta-llama/Llama-3.2-3B-Instruct", - "tools": ["brave_search"], + "tools": ["builtin::web_search"], "client_tools": [client_tool.get_tool_definition()], "tool_prompt_format": "python_list", }