mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 02:22:25 +00:00
rename UserDefinedToolDef to ToolDef
This commit is contained in:
parent
db0b2a60c1
commit
e3775eb6f6
8 changed files with 180 additions and 322 deletions
|
|
@ -36,7 +36,7 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.memory import MemoryBank
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import UserDefinedToolDef
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
|
|
@ -157,7 +157,7 @@ class AgentConfigCommon(BaseModel):
|
|||
input_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
tools: Optional[List[AgentTool]] = Field(default_factory=list)
|
||||
client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list)
|
||||
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
|
|
|
|||
|
|
@ -48,30 +48,16 @@ class Tool(Resource):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class UserDefinedToolDef(BaseModel):
|
||||
type: Literal["user_defined"] = "user_defined"
|
||||
class ToolDef(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
parameters: List[ToolParameter]
|
||||
metadata: Dict[str, Any]
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[List[ToolParameter]] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BuiltInToolDef(BaseModel):
|
||||
type: Literal["built_in"] = "built_in"
|
||||
built_in_type: BuiltinTool
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
ToolDef = register_schema(
|
||||
Annotated[Union[UserDefinedToolDef, BuiltInToolDef], Field(discriminator="type")],
|
||||
name="ToolDef",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MCPToolGroupDef(BaseModel):
|
||||
"""
|
||||
|
|
@ -100,7 +86,7 @@ ToolGroupDef = register_schema(
|
|||
@json_schema_type
|
||||
class ToolGroupInput(BaseModel):
|
||||
tool_group_id: str
|
||||
tool_group: ToolGroupDef
|
||||
tool_group_def: ToolGroupDef
|
||||
provider_id: Optional[str] = None
|
||||
|
||||
|
||||
|
|
@ -127,7 +113,7 @@ class ToolGroups(Protocol):
|
|||
async def register_tool_group(
|
||||
self,
|
||||
tool_group_id: str,
|
||||
tool_group: ToolGroupDef,
|
||||
tool_group_def: ToolGroupDef,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Register a tool group"""
|
||||
|
|
|
|||
|
|
@ -27,15 +27,12 @@ from llama_stack.apis.scoring_functions import (
|
|||
)
|
||||
from llama_stack.apis.shields import Shield, Shields
|
||||
from llama_stack.apis.tools import (
|
||||
BuiltInToolDef,
|
||||
MCPToolGroupDef,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
ToolGroupDef,
|
||||
ToolGroups,
|
||||
ToolHost,
|
||||
ToolPromptFormat,
|
||||
UserDefinedToolDef,
|
||||
UserDefinedToolGroupDef,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import (
|
||||
|
|
@ -514,7 +511,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
async def register_tool_group(
|
||||
self,
|
||||
tool_group_id: str,
|
||||
tool_group: ToolGroupDef,
|
||||
tool_group_def: ToolGroupDef,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> None:
|
||||
tools = []
|
||||
|
|
@ -528,47 +525,31 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
||||
# parse tool group to the type if dict
|
||||
tool_group = TypeAdapter(ToolGroupDef).validate_python(tool_group)
|
||||
if isinstance(tool_group, MCPToolGroupDef):
|
||||
tool_group_def = TypeAdapter(ToolGroupDef).validate_python(tool_group_def)
|
||||
if isinstance(tool_group_def, MCPToolGroupDef):
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
|
||||
tool_group
|
||||
tool_group_def
|
||||
)
|
||||
tool_host = ToolHost.model_context_protocol
|
||||
elif isinstance(tool_group, UserDefinedToolGroupDef):
|
||||
tool_defs = tool_group.tools
|
||||
elif isinstance(tool_group_def, UserDefinedToolGroupDef):
|
||||
tool_defs = tool_group_def.tools
|
||||
else:
|
||||
raise ValueError(f"Unknown tool group: {tool_group}")
|
||||
raise ValueError(f"Unknown tool group: {tool_group_def}")
|
||||
|
||||
for tool_def in tool_defs:
|
||||
if isinstance(tool_def, UserDefinedToolDef):
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool_def.name,
|
||||
tool_group=tool_group_id,
|
||||
description=tool_def.description,
|
||||
parameters=tool_def.parameters,
|
||||
provider_id=provider_id,
|
||||
tool_prompt_format=tool_def.tool_prompt_format,
|
||||
provider_resource_id=tool_def.name,
|
||||
metadata=tool_def.metadata,
|
||||
tool_host=tool_host,
|
||||
)
|
||||
)
|
||||
elif isinstance(tool_def, BuiltInToolDef):
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool_def.built_in_type.value,
|
||||
tool_group=tool_group_id,
|
||||
built_in_type=tool_def.built_in_type,
|
||||
description="",
|
||||
parameters=[],
|
||||
provider_id=provider_id,
|
||||
tool_prompt_format=ToolPromptFormat.json,
|
||||
provider_resource_id=tool_def.built_in_type.value,
|
||||
metadata=tool_def.metadata,
|
||||
tool_host=tool_host,
|
||||
)
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool_def.name,
|
||||
tool_group=tool_group_id,
|
||||
description=tool_def.description or "",
|
||||
parameters=tool_def.parameters or [],
|
||||
provider_id=provider_id,
|
||||
tool_prompt_format=tool_def.tool_prompt_format,
|
||||
provider_resource_id=tool_def.name,
|
||||
metadata=tool_def.metadata,
|
||||
tool_host=tool_host,
|
||||
)
|
||||
)
|
||||
for tool in tools:
|
||||
existing_tool = await self.get_tool(tool.identifier)
|
||||
# Compare existing and new object if one exists
|
||||
|
|
|
|||
|
|
@ -387,7 +387,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
extra_args = tool_args.get("memory", {})
|
||||
args = {
|
||||
tool_args = {
|
||||
# Query memory with the last message's content
|
||||
"query": input_messages[-1],
|
||||
**extra_args,
|
||||
|
|
@ -396,8 +396,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
session_info = await self.storage.get_session_info(session_id)
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
if session_info.memory_bank_id:
|
||||
args["memory_bank_id"] = session_info.memory_bank_id
|
||||
serialized_args = tracing.serialize_value(args)
|
||||
tool_args["memory_bank_id"] = session_info.memory_bank_id
|
||||
serialized_args = tracing.serialize_value(tool_args)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
|
|
@ -416,7 +416,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name="memory",
|
||||
args=args,
|
||||
args=tool_args,
|
||||
)
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
|
@ -482,11 +482,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=[
|
||||
tool
|
||||
for tool in tool_defs.values()
|
||||
if tool.tool_name != "memory"
|
||||
],
|
||||
tools=[tool for tool in tool_defs.values()],
|
||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||
stream=True,
|
||||
sampling_params=sampling_params,
|
||||
|
|
@ -728,10 +724,17 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
continue
|
||||
|
||||
tool_def = await self.tool_groups_api.get_tool(tool_name)
|
||||
if tool_def is None:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
if tool_def.built_in_type:
|
||||
ret[tool_def.built_in_type] = ToolDefinition(
|
||||
tool_name=tool_def.built_in_type
|
||||
if tool_def.identifier.startswith("builtin::"):
|
||||
built_in_type = tool_def.identifier[len("builtin::") :]
|
||||
if built_in_type == "web_search":
|
||||
built_in_type = "brave_search"
|
||||
if built_in_type not in BuiltinTool.__members__:
|
||||
raise ValueError(f"Unknown built-in tool: {built_in_type}")
|
||||
ret[built_in_type] = ToolDefinition(
|
||||
tool_name=BuiltinTool(built_in_type)
|
||||
)
|
||||
continue
|
||||
|
||||
|
|
@ -759,52 +762,52 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_defs: Dict[str, ToolDefinition],
|
||||
) -> None:
|
||||
memory_tool = tool_defs.get("memory", None)
|
||||
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
|
||||
if documents:
|
||||
content_items = []
|
||||
url_items = []
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
for d in documents:
|
||||
if isinstance(d.content, URL):
|
||||
url_items.append(d.content)
|
||||
elif pattern.match(d.content):
|
||||
url_items.append(URL(uri=d.content))
|
||||
else:
|
||||
content_items.append(d)
|
||||
|
||||
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
|
||||
if code_interpreter_tool:
|
||||
for c in content_items:
|
||||
temp_file_path = os.path.join(
|
||||
self.tempdir, f"{make_random_string()}.txt"
|
||||
)
|
||||
with open(temp_file_path, "w") as temp_file:
|
||||
temp_file.write(c.content)
|
||||
url_items.append(URL(uri=f"file://{temp_file_path}"))
|
||||
|
||||
if memory_tool and code_interpreter_tool:
|
||||
# if both memory and code_interpreter are available, we download the URLs
|
||||
# and attach the data to the last message.
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
# Since memory is present, add all the data to the memory bank
|
||||
await self.add_to_session_memory_bank(session_id, documents)
|
||||
elif code_interpreter_tool:
|
||||
# if only code_interpreter is available, we download the URLs to a tempdir
|
||||
# and attach the path to them as a message to inference with the
|
||||
# assumption that the model invokes the code_interpreter tool with the path
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
elif memory_tool:
|
||||
# if only memory is available, we load the data from the URLs and content items to the memory bank
|
||||
await self.add_to_session_memory_bank(session_id, documents)
|
||||
code_interpreter_tool = tool_defs.get("code_interpreter", None)
|
||||
content_items = []
|
||||
url_items = []
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
for d in documents:
|
||||
if isinstance(d.content, URL):
|
||||
url_items.append(d.content)
|
||||
elif pattern.match(d.content):
|
||||
url_items.append(URL(uri=d.content))
|
||||
else:
|
||||
# if no memory or code_interpreter tool is available,
|
||||
# we try to load the data from the URLs and content items as a message to inference
|
||||
# and add it to the last message's context
|
||||
input_messages[-1].context = content_items + await load_data_from_urls(
|
||||
url_items
|
||||
content_items.append(d)
|
||||
|
||||
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
|
||||
if code_interpreter_tool:
|
||||
for c in content_items:
|
||||
temp_file_path = os.path.join(
|
||||
self.tempdir, f"{make_random_string()}.txt"
|
||||
)
|
||||
with open(temp_file_path, "w") as temp_file:
|
||||
temp_file.write(c.content)
|
||||
url_items.append(URL(uri=f"file://{temp_file_path}"))
|
||||
|
||||
if memory_tool and code_interpreter_tool:
|
||||
# if both memory and code_interpreter are available, we download the URLs
|
||||
# and attach the data to the last message.
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
# Since memory is present, add all the data to the memory bank
|
||||
await self.add_to_session_memory_bank(session_id, documents)
|
||||
elif code_interpreter_tool:
|
||||
# if only code_interpreter is available, we download the URLs to a tempdir
|
||||
# and attach the path to them as a message to inference with the
|
||||
# assumption that the model invokes the code_interpreter tool with the path
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
elif memory_tool:
|
||||
# if only memory is available, we load the data from the URLs and content items to the memory bank
|
||||
await self.add_to_session_memory_bank(session_id, documents)
|
||||
else:
|
||||
# if no memory or code_interpreter tool is available,
|
||||
# we try to load the data from the URLs and content items as a message to inference
|
||||
# and add it to the last message's context
|
||||
input_messages[-1].context = "\n".join(
|
||||
[doc.content for doc in content_items]
|
||||
+ await load_data_from_urls(url_items)
|
||||
)
|
||||
|
||||
async def _ensure_memory_bank(self, session_id: str) -> str:
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
|
|
@ -909,7 +912,10 @@ async def execute_tool_call_maybe(
|
|||
tool_call = message.tool_calls[0]
|
||||
name = tool_call.tool_name
|
||||
if isinstance(name, BuiltinTool):
|
||||
name = name.value
|
||||
if name == BuiltinTool.brave_search:
|
||||
name = "builtin::web_search"
|
||||
else:
|
||||
name = "builtin::" + name.value
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
args=dict(
|
||||
|
|
|
|||
|
|
@ -30,8 +30,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
if tool.identifier != "code_interpreter":
|
||||
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
return
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue