rename UserDefinedToolDef to ToolDef

This commit is contained in:
Dinesh Yeduguru 2025-01-07 09:14:26 -08:00
parent db0b2a60c1
commit e3775eb6f6
8 changed files with 180 additions and 322 deletions

View file

@ -3714,7 +3714,7 @@
"client_tools": { "client_tools": {
"type": "array", "type": "array",
"items": { "items": {
"$ref": "#/components/schemas/UserDefinedToolDef" "$ref": "#/components/schemas/ToolDef"
} }
}, },
"tool_choice": { "tool_choice": {
@ -3792,60 +3792,9 @@
} }
] ]
}, },
"ToolParameter": { "ToolDef": {
"type": "object", "type": "object",
"properties": { "properties": {
"name": {
"type": "string"
},
"parameter_type": {
"type": "string"
},
"description": {
"type": "string"
},
"required": {
"type": "boolean"
},
"default": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"additionalProperties": false,
"required": [
"name",
"parameter_type",
"description",
"required"
]
},
"UserDefinedToolDef": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "user_defined",
"default": "user_defined"
},
"name": { "name": {
"type": "string" "type": "string"
}, },
@ -3890,11 +3839,53 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"type", "name"
]
},
"ToolParameter": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"parameter_type": {
"type": "string"
},
"description": {
"type": "string"
},
"required": {
"type": "boolean"
},
"default": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"additionalProperties": false,
"required": [
"name", "name",
"parameter_type",
"description", "description",
"parameters", "required"
"metadata"
] ]
}, },
"CreateAgentRequest": { "CreateAgentRequest": {
@ -4589,49 +4580,6 @@
"session_id" "session_id"
] ]
}, },
"BuiltInToolDef": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "built_in",
"default": "built_in"
},
"built_in_type": {
"$ref": "#/components/schemas/BuiltinTool"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"type",
"built_in_type"
]
},
"MCPToolGroupDef": { "MCPToolGroupDef": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -4651,16 +4599,6 @@
], ],
"title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information." "title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information."
}, },
"ToolDef": {
"oneOf": [
{
"$ref": "#/components/schemas/UserDefinedToolDef"
},
{
"$ref": "#/components/schemas/BuiltInToolDef"
}
]
},
"ToolGroupDef": { "ToolGroupDef": {
"oneOf": [ "oneOf": [
{ {
@ -7436,7 +7374,7 @@
"tool_group_id": { "tool_group_id": {
"type": "string" "type": "string"
}, },
"tool_group": { "tool_group_def": {
"$ref": "#/components/schemas/ToolGroupDef" "$ref": "#/components/schemas/ToolGroupDef"
}, },
"provider_id": { "provider_id": {
@ -7446,7 +7384,7 @@
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"tool_group_id", "tool_group_id",
"tool_group" "tool_group_def"
] ]
}, },
"RunEvalRequest": { "RunEvalRequest": {
@ -8098,10 +8036,6 @@
"name": "BenchmarkEvalTaskConfig", "name": "BenchmarkEvalTaskConfig",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BenchmarkEvalTaskConfig\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/BenchmarkEvalTaskConfig\" />"
}, },
{
"name": "BuiltInToolDef",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltInToolDef\" />"
},
{ {
"name": "BuiltinTool", "name": "BuiltinTool",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />"
@ -8708,10 +8642,6 @@
"name": "UnstructuredLogEvent", "name": "UnstructuredLogEvent",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
}, },
{
"name": "UserDefinedToolDef",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UserDefinedToolDef\" />"
},
{ {
"name": "UserDefinedToolGroupDef", "name": "UserDefinedToolGroupDef",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UserDefinedToolGroupDef\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/UserDefinedToolGroupDef\" />"
@ -8792,7 +8722,6 @@
"BatchCompletionRequest", "BatchCompletionRequest",
"BatchCompletionResponse", "BatchCompletionResponse",
"BenchmarkEvalTaskConfig", "BenchmarkEvalTaskConfig",
"BuiltInToolDef",
"BuiltinTool", "BuiltinTool",
"CancelTrainingJobRequest", "CancelTrainingJobRequest",
"ChatCompletionRequest", "ChatCompletionRequest",
@ -8931,7 +8860,6 @@
"UnregisterModelRequest", "UnregisterModelRequest",
"UnregisterToolGroupRequest", "UnregisterToolGroupRequest",
"UnstructuredLogEvent", "UnstructuredLogEvent",
"UserDefinedToolDef",
"UserDefinedToolGroupDef", "UserDefinedToolGroupDef",
"UserMessage", "UserMessage",
"VectorMemoryBank", "VectorMemoryBank",

View file

@ -19,7 +19,7 @@ components:
properties: properties:
client_tools: client_tools:
items: items:
$ref: '#/components/schemas/UserDefinedToolDef' $ref: '#/components/schemas/ToolDef'
type: array type: array
enable_session_persistence: enable_session_persistence:
type: boolean type: boolean
@ -396,29 +396,6 @@ components:
- type - type
- eval_candidate - eval_candidate
type: object type: object
BuiltInToolDef:
additionalProperties: false
properties:
built_in_type:
$ref: '#/components/schemas/BuiltinTool'
metadata:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
type:
const: built_in
default: built_in
type: string
required:
- type
- built_in_type
type: object
BuiltinTool: BuiltinTool:
enum: enum:
- brave_search - brave_search
@ -1929,13 +1906,13 @@ components:
properties: properties:
provider_id: provider_id:
type: string type: string
tool_group: tool_group_def:
$ref: '#/components/schemas/ToolGroupDef' $ref: '#/components/schemas/ToolGroupDef'
tool_group_id: tool_group_id:
type: string type: string
required: required:
- tool_group_id - tool_group_id
- tool_group - tool_group_def
type: object type: object
ResponseFormat: ResponseFormat:
oneOf: oneOf:
@ -2716,9 +2693,32 @@ components:
- required - required
type: string type: string
ToolDef: ToolDef:
oneOf: additionalProperties: false
- $ref: '#/components/schemas/UserDefinedToolDef' properties:
- $ref: '#/components/schemas/BuiltInToolDef' description:
type: string
metadata:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
name:
type: string
parameters:
items:
$ref: '#/components/schemas/ToolParameter'
type: array
tool_prompt_format:
$ref: '#/components/schemas/ToolPromptFormat'
default: json
required:
- name
type: object
ToolDefinition: ToolDefinition:
additionalProperties: false additionalProperties: false
properties: properties:
@ -3087,41 +3087,6 @@ components:
- message - message
- severity - severity
type: object type: object
UserDefinedToolDef:
additionalProperties: false
properties:
description:
type: string
metadata:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
name:
type: string
parameters:
items:
$ref: '#/components/schemas/ToolParameter'
type: array
tool_prompt_format:
$ref: '#/components/schemas/ToolPromptFormat'
default: json
type:
const: user_defined
default: user_defined
type: string
required:
- type
- name
- description
- parameters
- metadata
type: object
UserDefinedToolGroupDef: UserDefinedToolGroupDef:
additionalProperties: false additionalProperties: false
properties: properties:
@ -4823,8 +4788,6 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/BenchmarkEvalTaskConfig" - description: <SchemaDefinition schemaRef="#/components/schemas/BenchmarkEvalTaskConfig"
/> />
name: BenchmarkEvalTaskConfig name: BenchmarkEvalTaskConfig
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltInToolDef" />
name: BuiltInToolDef
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" /> - description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CancelTrainingJobRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/CancelTrainingJobRequest"
@ -5251,9 +5214,6 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent" - description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
/> />
name: UnstructuredLogEvent name: UnstructuredLogEvent
- description: <SchemaDefinition schemaRef="#/components/schemas/UserDefinedToolDef"
/>
name: UserDefinedToolDef
- description: <SchemaDefinition schemaRef="#/components/schemas/UserDefinedToolGroupDef" - description: <SchemaDefinition schemaRef="#/components/schemas/UserDefinedToolGroupDef"
/> />
name: UserDefinedToolGroupDef name: UserDefinedToolGroupDef
@ -5316,7 +5276,6 @@ x-tagGroups:
- BatchCompletionRequest - BatchCompletionRequest
- BatchCompletionResponse - BatchCompletionResponse
- BenchmarkEvalTaskConfig - BenchmarkEvalTaskConfig
- BuiltInToolDef
- BuiltinTool - BuiltinTool
- CancelTrainingJobRequest - CancelTrainingJobRequest
- ChatCompletionRequest - ChatCompletionRequest
@ -5455,7 +5414,6 @@ x-tagGroups:
- UnregisterModelRequest - UnregisterModelRequest
- UnregisterToolGroupRequest - UnregisterToolGroupRequest
- UnstructuredLogEvent - UnstructuredLogEvent
- UserDefinedToolDef
- UserDefinedToolGroupDef - UserDefinedToolGroupDef
- UserMessage - UserMessage
- VectorMemoryBank - VectorMemoryBank

View file

@ -36,7 +36,7 @@ from llama_stack.apis.inference import (
) )
from llama_stack.apis.memory import MemoryBank from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import UserDefinedToolDef from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -157,7 +157,7 @@ class AgentConfigCommon(BaseModel):
input_shields: Optional[List[str]] = Field(default_factory=list) input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list)
tools: Optional[List[AgentTool]] = Field(default_factory=list) tools: Optional[List[AgentTool]] = Field(default_factory=list)
client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list) client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field( tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json default=ToolPromptFormat.json

View file

@ -48,30 +48,16 @@ class Tool(Resource):
@json_schema_type @json_schema_type
class UserDefinedToolDef(BaseModel): class ToolDef(BaseModel):
type: Literal["user_defined"] = "user_defined"
name: str name: str
description: str description: Optional[str] = None
parameters: List[ToolParameter] parameters: Optional[List[ToolParameter]] = None
metadata: Dict[str, Any] metadata: Optional[Dict[str, Any]] = None
tool_prompt_format: Optional[ToolPromptFormat] = Field( tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json default=ToolPromptFormat.json
) )
@json_schema_type
class BuiltInToolDef(BaseModel):
type: Literal["built_in"] = "built_in"
built_in_type: BuiltinTool
metadata: Optional[Dict[str, Any]] = None
ToolDef = register_schema(
Annotated[Union[UserDefinedToolDef, BuiltInToolDef], Field(discriminator="type")],
name="ToolDef",
)
@json_schema_type @json_schema_type
class MCPToolGroupDef(BaseModel): class MCPToolGroupDef(BaseModel):
""" """
@ -100,7 +86,7 @@ ToolGroupDef = register_schema(
@json_schema_type @json_schema_type
class ToolGroupInput(BaseModel): class ToolGroupInput(BaseModel):
tool_group_id: str tool_group_id: str
tool_group: ToolGroupDef tool_group_def: ToolGroupDef
provider_id: Optional[str] = None provider_id: Optional[str] = None
@ -127,7 +113,7 @@ class ToolGroups(Protocol):
async def register_tool_group( async def register_tool_group(
self, self,
tool_group_id: str, tool_group_id: str,
tool_group: ToolGroupDef, tool_group_def: ToolGroupDef,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
) -> None: ) -> None:
"""Register a tool group""" """Register a tool group"""

View file

@ -27,15 +27,12 @@ from llama_stack.apis.scoring_functions import (
) )
from llama_stack.apis.shields import Shield, Shields from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
BuiltInToolDef,
MCPToolGroupDef, MCPToolGroupDef,
Tool, Tool,
ToolGroup, ToolGroup,
ToolGroupDef, ToolGroupDef,
ToolGroups, ToolGroups,
ToolHost, ToolHost,
ToolPromptFormat,
UserDefinedToolDef,
UserDefinedToolGroupDef, UserDefinedToolGroupDef,
) )
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
@ -514,7 +511,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def register_tool_group( async def register_tool_group(
self, self,
tool_group_id: str, tool_group_id: str,
tool_group: ToolGroupDef, tool_group_def: ToolGroupDef,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
) -> None: ) -> None:
tools = [] tools = []
@ -528,47 +525,31 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
provider_id = list(self.impls_by_provider_id.keys())[0] provider_id = list(self.impls_by_provider_id.keys())[0]
# parse tool group to the type if dict # parse tool group to the type if dict
tool_group = TypeAdapter(ToolGroupDef).validate_python(tool_group) tool_group_def = TypeAdapter(ToolGroupDef).validate_python(tool_group_def)
if isinstance(tool_group, MCPToolGroupDef): if isinstance(tool_group_def, MCPToolGroupDef):
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools( tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group tool_group_def
) )
tool_host = ToolHost.model_context_protocol tool_host = ToolHost.model_context_protocol
elif isinstance(tool_group, UserDefinedToolGroupDef): elif isinstance(tool_group_def, UserDefinedToolGroupDef):
tool_defs = tool_group.tools tool_defs = tool_group_def.tools
else: else:
raise ValueError(f"Unknown tool group: {tool_group}") raise ValueError(f"Unknown tool group: {tool_group_def}")
for tool_def in tool_defs: for tool_def in tool_defs:
if isinstance(tool_def, UserDefinedToolDef): tools.append(
tools.append( Tool(
Tool( identifier=tool_def.name,
identifier=tool_def.name, tool_group=tool_group_id,
tool_group=tool_group_id, description=tool_def.description or "",
description=tool_def.description, parameters=tool_def.parameters or [],
parameters=tool_def.parameters, provider_id=provider_id,
provider_id=provider_id, tool_prompt_format=tool_def.tool_prompt_format,
tool_prompt_format=tool_def.tool_prompt_format, provider_resource_id=tool_def.name,
provider_resource_id=tool_def.name, metadata=tool_def.metadata,
metadata=tool_def.metadata, tool_host=tool_host,
tool_host=tool_host,
)
)
elif isinstance(tool_def, BuiltInToolDef):
tools.append(
Tool(
identifier=tool_def.built_in_type.value,
tool_group=tool_group_id,
built_in_type=tool_def.built_in_type,
description="",
parameters=[],
provider_id=provider_id,
tool_prompt_format=ToolPromptFormat.json,
provider_resource_id=tool_def.built_in_type.value,
metadata=tool_def.metadata,
tool_host=tool_host,
)
) )
)
for tool in tools: for tool in tools:
existing_tool = await self.get_tool(tool.identifier) existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists # Compare existing and new object if one exists

View file

@ -387,7 +387,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
extra_args = tool_args.get("memory", {}) extra_args = tool_args.get("memory", {})
args = { tool_args = {
# Query memory with the last message's content # Query memory with the last message's content
"query": input_messages[-1], "query": input_messages[-1],
**extra_args, **extra_args,
@ -396,8 +396,8 @@ class ChatAgent(ShieldRunnerMixin):
session_info = await self.storage.get_session_info(session_id) session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it # if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id: if session_info.memory_bank_id:
args["memory_bank_id"] = session_info.memory_bank_id tool_args["memory_bank_id"] = session_info.memory_bank_id
serialized_args = tracing.serialize_value(args) serialized_args = tracing.serialize_value(tool_args)
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
@ -416,7 +416,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
result = await self.tool_runtime_api.invoke_tool( result = await self.tool_runtime_api.invoke_tool(
tool_name="memory", tool_name="memory",
args=args, args=tool_args,
) )
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
@ -482,11 +482,7 @@ class ChatAgent(ShieldRunnerMixin):
async for chunk in await self.inference_api.chat_completion( async for chunk in await self.inference_api.chat_completion(
self.agent_config.model, self.agent_config.model,
input_messages, input_messages,
tools=[ tools=[tool for tool in tool_defs.values()],
tool
for tool in tool_defs.values()
if tool.tool_name != "memory"
],
tool_prompt_format=self.agent_config.tool_prompt_format, tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True, stream=True,
sampling_params=sampling_params, sampling_params=sampling_params,
@ -728,10 +724,17 @@ class ChatAgent(ShieldRunnerMixin):
continue continue
tool_def = await self.tool_groups_api.get_tool(tool_name) tool_def = await self.tool_groups_api.get_tool(tool_name)
if tool_def is None:
raise ValueError(f"Tool {tool_name} not found")
if tool_def.built_in_type: if tool_def.identifier.startswith("builtin::"):
ret[tool_def.built_in_type] = ToolDefinition( built_in_type = tool_def.identifier[len("builtin::") :]
tool_name=tool_def.built_in_type if built_in_type == "web_search":
built_in_type = "brave_search"
if built_in_type not in BuiltinTool.__members__:
raise ValueError(f"Unknown built-in tool: {built_in_type}")
ret[built_in_type] = ToolDefinition(
tool_name=BuiltinTool(built_in_type)
) )
continue continue
@ -759,52 +762,52 @@ class ChatAgent(ShieldRunnerMixin):
tool_defs: Dict[str, ToolDefinition], tool_defs: Dict[str, ToolDefinition],
) -> None: ) -> None:
memory_tool = tool_defs.get("memory", None) memory_tool = tool_defs.get("memory", None)
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None) code_interpreter_tool = tool_defs.get("code_interpreter", None)
if documents: content_items = []
content_items = [] url_items = []
url_items = [] pattern = re.compile("^(https?://|file://|data:)")
pattern = re.compile("^(https?://|file://|data:)") for d in documents:
for d in documents: if isinstance(d.content, URL):
if isinstance(d.content, URL): url_items.append(d.content)
url_items.append(d.content) elif pattern.match(d.content):
elif pattern.match(d.content): url_items.append(URL(uri=d.content))
url_items.append(URL(uri=d.content))
else:
content_items.append(d)
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
if code_interpreter_tool:
for c in content_items:
temp_file_path = os.path.join(
self.tempdir, f"{make_random_string()}.txt"
)
with open(temp_file_path, "w") as temp_file:
temp_file.write(c.content)
url_items.append(URL(uri=f"file://{temp_file_path}"))
if memory_tool and code_interpreter_tool:
# if both memory and code_interpreter are available, we download the URLs
# and attach the data to the last message.
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
# Since memory is present, add all the data to the memory bank
await self.add_to_session_memory_bank(session_id, documents)
elif code_interpreter_tool:
# if only code_interpreter is available, we download the URLs to a tempdir
# and attach the path to them as a message to inference with the
# assumption that the model invokes the code_interpreter tool with the path
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
elif memory_tool:
# if only memory is available, we load the data from the URLs and content items to the memory bank
await self.add_to_session_memory_bank(session_id, documents)
else: else:
# if no memory or code_interpreter tool is available, content_items.append(d)
# we try to load the data from the URLs and content items as a message to inference
# and add it to the last message's context # Save the contents to a tempdir and use its path as a URL if code interpreter is present
input_messages[-1].context = content_items + await load_data_from_urls( if code_interpreter_tool:
url_items for c in content_items:
temp_file_path = os.path.join(
self.tempdir, f"{make_random_string()}.txt"
) )
with open(temp_file_path, "w") as temp_file:
temp_file.write(c.content)
url_items.append(URL(uri=f"file://{temp_file_path}"))
if memory_tool and code_interpreter_tool:
# if both memory and code_interpreter are available, we download the URLs
# and attach the data to the last message.
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
# Since memory is present, add all the data to the memory bank
await self.add_to_session_memory_bank(session_id, documents)
elif code_interpreter_tool:
# if only code_interpreter is available, we download the URLs to a tempdir
# and attach the path to them as a message to inference with the
# assumption that the model invokes the code_interpreter tool with the path
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
elif memory_tool:
# if only memory is available, we load the data from the URLs and content items to the memory bank
await self.add_to_session_memory_bank(session_id, documents)
else:
# if no memory or code_interpreter tool is available,
# we try to load the data from the URLs and content items as a message to inference
# and add it to the last message's context
input_messages[-1].context = "\n".join(
[doc.content for doc in content_items]
+ await load_data_from_urls(url_items)
)
async def _ensure_memory_bank(self, session_id: str) -> str: async def _ensure_memory_bank(self, session_id: str) -> str:
session_info = await self.storage.get_session_info(session_id) session_info = await self.storage.get_session_info(session_id)
@ -909,7 +912,10 @@ async def execute_tool_call_maybe(
tool_call = message.tool_calls[0] tool_call = message.tool_calls[0]
name = tool_call.tool_name name = tool_call.tool_name
if isinstance(name, BuiltinTool): if isinstance(name, BuiltinTool):
name = name.value if name == BuiltinTool.brave_search:
name = "builtin::web_search"
else:
name = "builtin::" + name.value
result = await tool_runtime_api.invoke_tool( result = await tool_runtime_api.invoke_tool(
tool_name=name, tool_name=name,
args=dict( args=dict(

View file

@ -30,8 +30,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
pass pass
async def register_tool(self, tool: Tool): async def register_tool(self, tool: Tool):
if tool.identifier != "code_interpreter": pass
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
async def unregister_tool(self, tool_id: str) -> None: async def unregister_tool(self, tool_id: str) -> None:
return return

View file

@ -17,7 +17,7 @@ from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
from llama_stack_client.types.memory_insert_params import Document from llama_stack_client.types.memory_insert_params import Document
from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.tool_def_param import UserDefinedToolDefParameter from llama_stack_client.types.tool_def_param import Parameter
class TestClientTool(ClientTool): class TestClientTool(ClientTool):
@ -53,15 +53,15 @@ class TestClientTool(ClientTool):
def get_description(self) -> str: def get_description(self) -> str:
return "Get the boiling point of imaginary liquids (eg. polyjuice)" return "Get the boiling point of imaginary liquids (eg. polyjuice)"
def get_params_definition(self) -> Dict[str, UserDefinedToolDefParameter]: def get_params_definition(self) -> Dict[str, Parameter]:
return { return {
"liquid_name": UserDefinedToolDefParameter( "liquid_name": Parameter(
name="liquid_name", name="liquid_name",
parameter_type="string", parameter_type="string",
description="The name of the liquid", description="The name of the liquid",
required=True, required=True,
), ),
"celcius": UserDefinedToolDefParameter( "celcius": Parameter(
name="celcius", name="celcius",
parameter_type="boolean", parameter_type="boolean",
description="Whether to return the boiling point in Celcius", description="Whether to return the boiling point in Celcius",
@ -149,11 +149,11 @@ def test_agent_simple(llama_stack_client, agent_config):
assert "I can't" in logs_str assert "I can't" in logs_str
def test_builtin_tool_brave_search(llama_stack_client, agent_config): def test_builtin_tool_web_search(llama_stack_client, agent_config):
agent_config = { agent_config = {
**agent_config, **agent_config,
"tools": [ "tools": [
"brave_search", "builtin::web_search",
], ],
} }
agent = Agent(llama_stack_client, agent_config) agent = Agent(llama_stack_client, agent_config)
@ -182,7 +182,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
agent_config = { agent_config = {
**agent_config, **agent_config,
"tools": [ "tools": [
"code_interpreter", "builtin::code_interpreter",
], ],
} }
agent = Agent(llama_stack_client, agent_config) agent = Agent(llama_stack_client, agent_config)
@ -209,9 +209,9 @@ def test_code_execution(llama_stack_client):
model="meta-llama/Llama-3.1-70B-Instruct", model="meta-llama/Llama-3.1-70B-Instruct",
instructions="You are a helpful assistant", instructions="You are a helpful assistant",
tools=[ tools=[
"code_interpreter", "builtin::code_interpreter",
], ],
tool_choice="auto", tool_choice="required",
input_shields=[], input_shields=[],
output_shields=[], output_shields=[],
enable_session_persistence=False, enable_session_persistence=False,
@ -242,7 +242,7 @@ def test_code_execution(llama_stack_client):
) )
logs = [str(log) for log in EventLogger().log(response) if log is not None] logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs) logs_str = "".join(logs)
print(logs_str) assert "Tool:code_interpreter" in logs_str
def test_custom_tool(llama_stack_client, agent_config): def test_custom_tool(llama_stack_client, agent_config):
@ -250,7 +250,7 @@ def test_custom_tool(llama_stack_client, agent_config):
agent_config = { agent_config = {
**agent_config, **agent_config,
"model": "meta-llama/Llama-3.2-3B-Instruct", "model": "meta-llama/Llama-3.2-3B-Instruct",
"tools": ["brave_search"], "tools": ["builtin::web_search"],
"client_tools": [client_tool.get_tool_definition()], "client_tools": [client_tool.get_tool_definition()],
"tool_prompt_format": "python_list", "tool_prompt_format": "python_list",
} }