rename UserDefinedToolDef to ToolDef

This commit is contained in:
Dinesh Yeduguru 2025-01-07 09:14:26 -08:00
parent db0b2a60c1
commit e3775eb6f6
8 changed files with 180 additions and 322 deletions

View file

@ -387,7 +387,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
extra_args = tool_args.get("memory", {})
args = {
tool_args = {
# Query memory with the last message's content
"query": input_messages[-1],
**extra_args,
@ -396,8 +396,8 @@ class ChatAgent(ShieldRunnerMixin):
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id:
args["memory_bank_id"] = session_info.memory_bank_id
serialized_args = tracing.serialize_value(args)
tool_args["memory_bank_id"] = session_info.memory_bank_id
serialized_args = tracing.serialize_value(tool_args)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
@ -416,7 +416,7 @@ class ChatAgent(ShieldRunnerMixin):
)
result = await self.tool_runtime_api.invoke_tool(
tool_name="memory",
args=args,
args=tool_args,
)
yield AgentTurnResponseStreamChunk(
@ -482,11 +482,7 @@ class ChatAgent(ShieldRunnerMixin):
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=[
tool
for tool in tool_defs.values()
if tool.tool_name != "memory"
],
tools=[tool for tool in tool_defs.values()],
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=sampling_params,
@ -728,10 +724,17 @@ class ChatAgent(ShieldRunnerMixin):
continue
tool_def = await self.tool_groups_api.get_tool(tool_name)
if tool_def is None:
raise ValueError(f"Tool {tool_name} not found")
if tool_def.built_in_type:
ret[tool_def.built_in_type] = ToolDefinition(
tool_name=tool_def.built_in_type
if tool_def.identifier.startswith("builtin::"):
built_in_type = tool_def.identifier[len("builtin::") :]
if built_in_type == "web_search":
built_in_type = "brave_search"
if built_in_type not in BuiltinTool.__members__:
raise ValueError(f"Unknown built-in tool: {built_in_type}")
ret[built_in_type] = ToolDefinition(
tool_name=BuiltinTool(built_in_type)
)
continue
@ -759,52 +762,52 @@ class ChatAgent(ShieldRunnerMixin):
tool_defs: Dict[str, ToolDefinition],
) -> None:
memory_tool = tool_defs.get("memory", None)
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
if documents:
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")
for d in documents:
if isinstance(d.content, URL):
url_items.append(d.content)
elif pattern.match(d.content):
url_items.append(URL(uri=d.content))
else:
content_items.append(d)
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
if code_interpreter_tool:
for c in content_items:
temp_file_path = os.path.join(
self.tempdir, f"{make_random_string()}.txt"
)
with open(temp_file_path, "w") as temp_file:
temp_file.write(c.content)
url_items.append(URL(uri=f"file://{temp_file_path}"))
if memory_tool and code_interpreter_tool:
# if both memory and code_interpreter are available, we download the URLs
# and attach the data to the last message.
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
# Since memory is present, add all the data to the memory bank
await self.add_to_session_memory_bank(session_id, documents)
elif code_interpreter_tool:
# if only code_interpreter is available, we download the URLs to a tempdir
# and attach the path to them as a message to inference with the
# assumption that the model invokes the code_interpreter tool with the path
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
elif memory_tool:
# if only memory is available, we load the data from the URLs and content items to the memory bank
await self.add_to_session_memory_bank(session_id, documents)
code_interpreter_tool = tool_defs.get("code_interpreter", None)
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")
for d in documents:
if isinstance(d.content, URL):
url_items.append(d.content)
elif pattern.match(d.content):
url_items.append(URL(uri=d.content))
else:
# if no memory or code_interpreter tool is available,
# we try to load the data from the URLs and content items as a message to inference
# and add it to the last message's context
input_messages[-1].context = content_items + await load_data_from_urls(
url_items
content_items.append(d)
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
if code_interpreter_tool:
for c in content_items:
temp_file_path = os.path.join(
self.tempdir, f"{make_random_string()}.txt"
)
with open(temp_file_path, "w") as temp_file:
temp_file.write(c.content)
url_items.append(URL(uri=f"file://{temp_file_path}"))
if memory_tool and code_interpreter_tool:
# if both memory and code_interpreter are available, we download the URLs
# and attach the data to the last message.
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
# Since memory is present, add all the data to the memory bank
await self.add_to_session_memory_bank(session_id, documents)
elif code_interpreter_tool:
# if only code_interpreter is available, we download the URLs to a tempdir
# and attach the path to them as a message to inference with the
# assumption that the model invokes the code_interpreter tool with the path
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
elif memory_tool:
# if only memory is available, we load the data from the URLs and content items to the memory bank
await self.add_to_session_memory_bank(session_id, documents)
else:
# if no memory or code_interpreter tool is available,
# we try to load the data from the URLs and content items as a message to inference
# and add it to the last message's context
input_messages[-1].context = "\n".join(
[doc.content for doc in content_items]
+ await load_data_from_urls(url_items)
)
async def _ensure_memory_bank(self, session_id: str) -> str:
session_info = await self.storage.get_session_info(session_id)
@ -909,7 +912,10 @@ async def execute_tool_call_maybe(
tool_call = message.tool_calls[0]
name = tool_call.tool_name
if isinstance(name, BuiltinTool):
name = name.value
if name == BuiltinTool.brave_search:
name = "builtin::web_search"
else:
name = "builtin::" + name.value
result = await tool_runtime_api.invoke_tool(
tool_name=name,
args=dict(

View file

@ -30,8 +30,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
pass
async def register_tool(self, tool: Tool):
if tool.identifier != "code_interpreter":
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
pass
async def unregister_tool(self, tool_id: str) -> None:
return