diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 5a78f5bae..fb7525988 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -3714,7 +3714,7 @@
"client_tools": {
"type": "array",
"items": {
- "$ref": "#/components/schemas/UserDefinedToolDef"
+ "$ref": "#/components/schemas/ToolDef"
}
},
"tool_choice": {
@@ -3792,60 +3792,9 @@
}
]
},
- "ToolParameter": {
+ "ToolDef": {
"type": "object",
"properties": {
- "name": {
- "type": "string"
- },
- "parameter_type": {
- "type": "string"
- },
- "description": {
- "type": "string"
- },
- "required": {
- "type": "boolean"
- },
- "default": {
- "oneOf": [
- {
- "type": "null"
- },
- {
- "type": "boolean"
- },
- {
- "type": "number"
- },
- {
- "type": "string"
- },
- {
- "type": "array"
- },
- {
- "type": "object"
- }
- ]
- }
- },
- "additionalProperties": false,
- "required": [
- "name",
- "parameter_type",
- "description",
- "required"
- ]
- },
- "UserDefinedToolDef": {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "user_defined",
- "default": "user_defined"
- },
"name": {
"type": "string"
},
@@ -3890,11 +3839,53 @@
},
"additionalProperties": false,
"required": [
- "type",
+ "name"
+ ]
+ },
+ "ToolParameter": {
+ "type": "object",
+ "properties": {
+ "name": {
+ "type": "string"
+ },
+ "parameter_type": {
+ "type": "string"
+ },
+ "description": {
+ "type": "string"
+ },
+ "required": {
+ "type": "boolean"
+ },
+ "default": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
+ },
+ "additionalProperties": false,
+ "required": [
"name",
+ "parameter_type",
"description",
- "parameters",
- "metadata"
+ "required"
]
},
"CreateAgentRequest": {
@@ -4589,49 +4580,6 @@
"session_id"
]
},
- "BuiltInToolDef": {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "built_in",
- "default": "built_in"
- },
- "built_in_type": {
- "$ref": "#/components/schemas/BuiltinTool"
- },
- "metadata": {
- "type": "object",
- "additionalProperties": {
- "oneOf": [
- {
- "type": "null"
- },
- {
- "type": "boolean"
- },
- {
- "type": "number"
- },
- {
- "type": "string"
- },
- {
- "type": "array"
- },
- {
- "type": "object"
- }
- ]
- }
- }
- },
- "additionalProperties": false,
- "required": [
- "type",
- "built_in_type"
- ]
- },
"MCPToolGroupDef": {
"type": "object",
"properties": {
@@ -4651,16 +4599,6 @@
],
"title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information."
},
- "ToolDef": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/UserDefinedToolDef"
- },
- {
- "$ref": "#/components/schemas/BuiltInToolDef"
- }
- ]
- },
"ToolGroupDef": {
"oneOf": [
{
@@ -7436,7 +7374,7 @@
"tool_group_id": {
"type": "string"
},
- "tool_group": {
+ "tool_group_def": {
"$ref": "#/components/schemas/ToolGroupDef"
},
"provider_id": {
@@ -7446,7 +7384,7 @@
"additionalProperties": false,
"required": [
"tool_group_id",
- "tool_group"
+ "tool_group_def"
]
},
"RunEvalRequest": {
@@ -8098,10 +8036,6 @@
"name": "BenchmarkEvalTaskConfig",
"description": ""
},
- {
- "name": "BuiltInToolDef",
- "description": ""
- },
{
"name": "BuiltinTool",
"description": ""
@@ -8708,10 +8642,6 @@
"name": "UnstructuredLogEvent",
"description": ""
},
- {
- "name": "UserDefinedToolDef",
- "description": ""
- },
{
"name": "UserDefinedToolGroupDef",
"description": ""
@@ -8792,7 +8722,6 @@
"BatchCompletionRequest",
"BatchCompletionResponse",
"BenchmarkEvalTaskConfig",
- "BuiltInToolDef",
"BuiltinTool",
"CancelTrainingJobRequest",
"ChatCompletionRequest",
@@ -8931,7 +8860,6 @@
"UnregisterModelRequest",
"UnregisterToolGroupRequest",
"UnstructuredLogEvent",
- "UserDefinedToolDef",
"UserDefinedToolGroupDef",
"UserMessage",
"VectorMemoryBank",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 72093b436..0937d8722 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -19,7 +19,7 @@ components:
properties:
client_tools:
items:
- $ref: '#/components/schemas/UserDefinedToolDef'
+ $ref: '#/components/schemas/ToolDef'
type: array
enable_session_persistence:
type: boolean
@@ -396,29 +396,6 @@ components:
- type
- eval_candidate
type: object
- BuiltInToolDef:
- additionalProperties: false
- properties:
- built_in_type:
- $ref: '#/components/schemas/BuiltinTool'
- metadata:
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
- type: object
- type:
- const: built_in
- default: built_in
- type: string
- required:
- - type
- - built_in_type
- type: object
BuiltinTool:
enum:
- brave_search
@@ -1929,13 +1906,13 @@ components:
properties:
provider_id:
type: string
- tool_group:
+ tool_group_def:
$ref: '#/components/schemas/ToolGroupDef'
tool_group_id:
type: string
required:
- tool_group_id
- - tool_group
+ - tool_group_def
type: object
ResponseFormat:
oneOf:
@@ -2716,9 +2693,32 @@ components:
- required
type: string
ToolDef:
- oneOf:
- - $ref: '#/components/schemas/UserDefinedToolDef'
- - $ref: '#/components/schemas/BuiltInToolDef'
+ additionalProperties: false
+ properties:
+ description:
+ type: string
+ metadata:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ name:
+ type: string
+ parameters:
+ items:
+ $ref: '#/components/schemas/ToolParameter'
+ type: array
+ tool_prompt_format:
+ $ref: '#/components/schemas/ToolPromptFormat'
+ default: json
+ required:
+ - name
+ type: object
ToolDefinition:
additionalProperties: false
properties:
@@ -3087,41 +3087,6 @@ components:
- message
- severity
type: object
- UserDefinedToolDef:
- additionalProperties: false
- properties:
- description:
- type: string
- metadata:
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
- type: object
- name:
- type: string
- parameters:
- items:
- $ref: '#/components/schemas/ToolParameter'
- type: array
- tool_prompt_format:
- $ref: '#/components/schemas/ToolPromptFormat'
- default: json
- type:
- const: user_defined
- default: user_defined
- type: string
- required:
- - type
- - name
- - description
- - parameters
- - metadata
- type: object
UserDefinedToolGroupDef:
additionalProperties: false
properties:
@@ -4823,8 +4788,6 @@ tags:
- description:
name: BenchmarkEvalTaskConfig
-- description:
- name: BuiltInToolDef
- description:
name: BuiltinTool
- description:
name: UnstructuredLogEvent
-- description:
- name: UserDefinedToolDef
- description:
name: UserDefinedToolGroupDef
@@ -5316,7 +5276,6 @@ x-tagGroups:
- BatchCompletionRequest
- BatchCompletionResponse
- BenchmarkEvalTaskConfig
- - BuiltInToolDef
- BuiltinTool
- CancelTrainingJobRequest
- ChatCompletionRequest
@@ -5455,7 +5414,6 @@ x-tagGroups:
- UnregisterModelRequest
- UnregisterToolGroupRequest
- UnstructuredLogEvent
- - UserDefinedToolDef
- UserDefinedToolGroupDef
- UserMessage
- VectorMemoryBank
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index acf8fa748..db0e3ab3b 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -36,7 +36,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.safety import SafetyViolation
-from llama_stack.apis.tools import UserDefinedToolDef
+from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -157,7 +157,7 @@ class AgentConfigCommon(BaseModel):
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
tools: Optional[List[AgentTool]] = Field(default_factory=list)
- client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list)
+ client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index 6585f3fd2..bc19a8a02 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -48,30 +48,16 @@ class Tool(Resource):
@json_schema_type
-class UserDefinedToolDef(BaseModel):
- type: Literal["user_defined"] = "user_defined"
+class ToolDef(BaseModel):
name: str
- description: str
- parameters: List[ToolParameter]
- metadata: Dict[str, Any]
+ description: Optional[str] = None
+ parameters: Optional[List[ToolParameter]] = None
+ metadata: Optional[Dict[str, Any]] = None
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
-@json_schema_type
-class BuiltInToolDef(BaseModel):
- type: Literal["built_in"] = "built_in"
- built_in_type: BuiltinTool
- metadata: Optional[Dict[str, Any]] = None
-
-
-ToolDef = register_schema(
- Annotated[Union[UserDefinedToolDef, BuiltInToolDef], Field(discriminator="type")],
- name="ToolDef",
-)
-
-
@json_schema_type
class MCPToolGroupDef(BaseModel):
"""
@@ -100,7 +86,7 @@ ToolGroupDef = register_schema(
@json_schema_type
class ToolGroupInput(BaseModel):
tool_group_id: str
- tool_group: ToolGroupDef
+ tool_group_def: ToolGroupDef
provider_id: Optional[str] = None
@@ -127,7 +113,7 @@ class ToolGroups(Protocol):
async def register_tool_group(
self,
tool_group_id: str,
- tool_group: ToolGroupDef,
+ tool_group_def: ToolGroupDef,
provider_id: Optional[str] = None,
) -> None:
"""Register a tool group"""
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index ccea470ae..b51de8fef 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -27,15 +27,12 @@ from llama_stack.apis.scoring_functions import (
)
from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import (
- BuiltInToolDef,
MCPToolGroupDef,
Tool,
ToolGroup,
ToolGroupDef,
ToolGroups,
ToolHost,
- ToolPromptFormat,
- UserDefinedToolDef,
UserDefinedToolGroupDef,
)
from llama_stack.distribution.datatypes import (
@@ -514,7 +511,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def register_tool_group(
self,
tool_group_id: str,
- tool_group: ToolGroupDef,
+ tool_group_def: ToolGroupDef,
provider_id: Optional[str] = None,
) -> None:
tools = []
@@ -528,47 +525,31 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
provider_id = list(self.impls_by_provider_id.keys())[0]
# parse tool group to the type if dict
- tool_group = TypeAdapter(ToolGroupDef).validate_python(tool_group)
- if isinstance(tool_group, MCPToolGroupDef):
+ tool_group_def = TypeAdapter(ToolGroupDef).validate_python(tool_group_def)
+ if isinstance(tool_group_def, MCPToolGroupDef):
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
- tool_group
+ tool_group_def
)
tool_host = ToolHost.model_context_protocol
- elif isinstance(tool_group, UserDefinedToolGroupDef):
- tool_defs = tool_group.tools
+ elif isinstance(tool_group_def, UserDefinedToolGroupDef):
+ tool_defs = tool_group_def.tools
else:
- raise ValueError(f"Unknown tool group: {tool_group}")
+ raise ValueError(f"Unknown tool group: {tool_group_def}")
for tool_def in tool_defs:
- if isinstance(tool_def, UserDefinedToolDef):
- tools.append(
- Tool(
- identifier=tool_def.name,
- tool_group=tool_group_id,
- description=tool_def.description,
- parameters=tool_def.parameters,
- provider_id=provider_id,
- tool_prompt_format=tool_def.tool_prompt_format,
- provider_resource_id=tool_def.name,
- metadata=tool_def.metadata,
- tool_host=tool_host,
- )
- )
- elif isinstance(tool_def, BuiltInToolDef):
- tools.append(
- Tool(
- identifier=tool_def.built_in_type.value,
- tool_group=tool_group_id,
- built_in_type=tool_def.built_in_type,
- description="",
- parameters=[],
- provider_id=provider_id,
- tool_prompt_format=ToolPromptFormat.json,
- provider_resource_id=tool_def.built_in_type.value,
- metadata=tool_def.metadata,
- tool_host=tool_host,
- )
+ tools.append(
+ Tool(
+ identifier=tool_def.name,
+ tool_group=tool_group_id,
+ description=tool_def.description or "",
+ parameters=tool_def.parameters or [],
+ provider_id=provider_id,
+ tool_prompt_format=tool_def.tool_prompt_format,
+ provider_resource_id=tool_def.name,
+ metadata=tool_def.metadata,
+ tool_host=tool_host,
)
+ )
for tool in tools:
existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index e4ebb3011..cea4146e9 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -387,7 +387,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
extra_args = tool_args.get("memory", {})
- args = {
+ tool_args = {
# Query memory with the last message's content
"query": input_messages[-1],
**extra_args,
@@ -396,8 +396,8 @@ class ChatAgent(ShieldRunnerMixin):
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id:
- args["memory_bank_id"] = session_info.memory_bank_id
- serialized_args = tracing.serialize_value(args)
+ tool_args["memory_bank_id"] = session_info.memory_bank_id
+ serialized_args = tracing.serialize_value(tool_args)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
@@ -416,7 +416,7 @@ class ChatAgent(ShieldRunnerMixin):
)
result = await self.tool_runtime_api.invoke_tool(
tool_name="memory",
- args=args,
+ args=tool_args,
)
yield AgentTurnResponseStreamChunk(
@@ -482,11 +482,7 @@ class ChatAgent(ShieldRunnerMixin):
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
- tools=[
- tool
- for tool in tool_defs.values()
- if tool.tool_name != "memory"
- ],
+ tools=[tool for tool in tool_defs.values()],
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=sampling_params,
@@ -728,10 +724,17 @@ class ChatAgent(ShieldRunnerMixin):
continue
tool_def = await self.tool_groups_api.get_tool(tool_name)
+ if tool_def is None:
+ raise ValueError(f"Tool {tool_name} not found")
- if tool_def.built_in_type:
- ret[tool_def.built_in_type] = ToolDefinition(
- tool_name=tool_def.built_in_type
+ if tool_def.identifier.startswith("builtin::"):
+ built_in_type = tool_def.identifier[len("builtin::") :]
+ if built_in_type == "web_search":
+ built_in_type = "brave_search"
+ if built_in_type not in BuiltinTool.__members__:
+ raise ValueError(f"Unknown built-in tool: {built_in_type}")
+ ret[built_in_type] = ToolDefinition(
+ tool_name=BuiltinTool(built_in_type)
)
continue
@@ -759,52 +762,52 @@ class ChatAgent(ShieldRunnerMixin):
tool_defs: Dict[str, ToolDefinition],
) -> None:
memory_tool = tool_defs.get("memory", None)
- code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
- if documents:
- content_items = []
- url_items = []
- pattern = re.compile("^(https?://|file://|data:)")
- for d in documents:
- if isinstance(d.content, URL):
- url_items.append(d.content)
- elif pattern.match(d.content):
- url_items.append(URL(uri=d.content))
- else:
- content_items.append(d)
-
- # Save the contents to a tempdir and use its path as a URL if code interpreter is present
- if code_interpreter_tool:
- for c in content_items:
- temp_file_path = os.path.join(
- self.tempdir, f"{make_random_string()}.txt"
- )
- with open(temp_file_path, "w") as temp_file:
- temp_file.write(c.content)
- url_items.append(URL(uri=f"file://{temp_file_path}"))
-
- if memory_tool and code_interpreter_tool:
- # if both memory and code_interpreter are available, we download the URLs
- # and attach the data to the last message.
- msg = await attachment_message(self.tempdir, url_items)
- input_messages.append(msg)
- # Since memory is present, add all the data to the memory bank
- await self.add_to_session_memory_bank(session_id, documents)
- elif code_interpreter_tool:
- # if only code_interpreter is available, we download the URLs to a tempdir
- # and attach the path to them as a message to inference with the
- # assumption that the model invokes the code_interpreter tool with the path
- msg = await attachment_message(self.tempdir, url_items)
- input_messages.append(msg)
- elif memory_tool:
- # if only memory is available, we load the data from the URLs and content items to the memory bank
- await self.add_to_session_memory_bank(session_id, documents)
+ code_interpreter_tool = tool_defs.get("code_interpreter", None)
+ content_items = []
+ url_items = []
+ pattern = re.compile("^(https?://|file://|data:)")
+ for d in documents:
+ if isinstance(d.content, URL):
+ url_items.append(d.content)
+ elif pattern.match(d.content):
+ url_items.append(URL(uri=d.content))
else:
- # if no memory or code_interpreter tool is available,
- # we try to load the data from the URLs and content items as a message to inference
- # and add it to the last message's context
- input_messages[-1].context = content_items + await load_data_from_urls(
- url_items
+ content_items.append(d)
+
+ # Save the contents to a tempdir and use its path as a URL if code interpreter is present
+ if code_interpreter_tool:
+ for c in content_items:
+ temp_file_path = os.path.join(
+ self.tempdir, f"{make_random_string()}.txt"
)
+ with open(temp_file_path, "w") as temp_file:
+ temp_file.write(c.content)
+ url_items.append(URL(uri=f"file://{temp_file_path}"))
+
+ if memory_tool and code_interpreter_tool:
+ # if both memory and code_interpreter are available, we download the URLs
+ # and attach the data to the last message.
+ msg = await attachment_message(self.tempdir, url_items)
+ input_messages.append(msg)
+ # Since memory is present, add all the data to the memory bank
+ await self.add_to_session_memory_bank(session_id, documents)
+ elif code_interpreter_tool:
+ # if only code_interpreter is available, we download the URLs to a tempdir
+ # and attach the path to them as a message to inference with the
+ # assumption that the model invokes the code_interpreter tool with the path
+ msg = await attachment_message(self.tempdir, url_items)
+ input_messages.append(msg)
+ elif memory_tool:
+ # if only memory is available, we load the data from the URLs and content items to the memory bank
+ await self.add_to_session_memory_bank(session_id, documents)
+ else:
+ # if no memory or code_interpreter tool is available,
+ # we try to load the data from the URLs and content items as a message to inference
+ # and add it to the last message's context
+ input_messages[-1].context = "\n".join(
+ [doc.content for doc in content_items]
+ + await load_data_from_urls(url_items)
+ )
async def _ensure_memory_bank(self, session_id: str) -> str:
session_info = await self.storage.get_session_info(session_id)
@@ -909,7 +912,10 @@ async def execute_tool_call_maybe(
tool_call = message.tool_calls[0]
name = tool_call.tool_name
if isinstance(name, BuiltinTool):
- name = name.value
+ if name == BuiltinTool.brave_search:
+ name = "builtin::web_search"
+ else:
+ name = "builtin::" + name.value
result = await tool_runtime_api.invoke_tool(
tool_name=name,
args=dict(
diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
index 2e062d6d7..0fe0d0243 100644
--- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
+++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
@@ -30,8 +30,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
pass
async def register_tool(self, tool: Tool):
- if tool.identifier != "code_interpreter":
- raise ValueError(f"Tool identifier {tool.identifier} is not supported")
+ pass
async def unregister_tool(self, tool_id: str) -> None:
return
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index ca39ca03f..a760bb08a 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -17,7 +17,7 @@ from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
from llama_stack_client.types.memory_insert_params import Document
from llama_stack_client.types.shared.completion_message import CompletionMessage
-from llama_stack_client.types.tool_def_param import UserDefinedToolDefParameter
+from llama_stack_client.types.tool_def_param import Parameter
class TestClientTool(ClientTool):
@@ -53,15 +53,15 @@ class TestClientTool(ClientTool):
def get_description(self) -> str:
return "Get the boiling point of imaginary liquids (eg. polyjuice)"
- def get_params_definition(self) -> Dict[str, UserDefinedToolDefParameter]:
+ def get_params_definition(self) -> Dict[str, Parameter]:
return {
- "liquid_name": UserDefinedToolDefParameter(
+ "liquid_name": Parameter(
name="liquid_name",
parameter_type="string",
description="The name of the liquid",
required=True,
),
- "celcius": UserDefinedToolDefParameter(
+ "celcius": Parameter(
name="celcius",
parameter_type="boolean",
description="Whether to return the boiling point in Celcius",
@@ -149,11 +149,11 @@ def test_agent_simple(llama_stack_client, agent_config):
assert "I can't" in logs_str
-def test_builtin_tool_brave_search(llama_stack_client, agent_config):
+def test_builtin_tool_web_search(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"tools": [
- "brave_search",
+ "builtin::web_search",
],
}
agent = Agent(llama_stack_client, agent_config)
@@ -182,7 +182,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"tools": [
- "code_interpreter",
+ "builtin::code_interpreter",
],
}
agent = Agent(llama_stack_client, agent_config)
@@ -209,9 +209,9 @@ def test_code_execution(llama_stack_client):
model="meta-llama/Llama-3.1-70B-Instruct",
instructions="You are a helpful assistant",
tools=[
- "code_interpreter",
+ "builtin::code_interpreter",
],
- tool_choice="auto",
+ tool_choice="required",
input_shields=[],
output_shields=[],
enable_session_persistence=False,
@@ -242,7 +242,7 @@ def test_code_execution(llama_stack_client):
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
- print(logs_str)
+ assert "Tool:code_interpreter" in logs_str
def test_custom_tool(llama_stack_client, agent_config):
@@ -250,7 +250,7 @@ def test_custom_tool(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"model": "meta-llama/Llama-3.2-3B-Instruct",
- "tools": ["brave_search"],
+ "tools": ["builtin::web_search"],
"client_tools": [client_tool.get_tool_definition()],
"tool_prompt_format": "python_list",
}