mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
more substantial cleanup of Tool vs. ToolDef crap
This commit is contained in:
parent
fa6ed28aea
commit
6749c853c0
34 changed files with 2676 additions and 615 deletions
|
@ -19,48 +19,6 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
from .rag_tool import RAGToolRuntime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolParameter(BaseModel):
|
||||
"""Parameter definition for a tool.
|
||||
|
||||
:param name: Name of the parameter
|
||||
:param parameter_type: Type of the parameter (e.g., string, integer)
|
||||
:param description: Human-readable description of what the parameter does
|
||||
:param required: Whether this parameter is required for tool invocation
|
||||
:param items: Type of the elements when parameter_type is array
|
||||
:param title: (Optional) Title of the parameter
|
||||
:param default: (Optional) Default value for the parameter if not provided
|
||||
"""
|
||||
|
||||
name: str
|
||||
parameter_type: str
|
||||
description: str
|
||||
required: bool = Field(default=True)
|
||||
items: dict | None = None
|
||||
title: str | None = None
|
||||
default: Any | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Tool(Resource):
|
||||
"""A tool that can be invoked by agents.
|
||||
|
||||
:param type: Type of resource, always 'tool'
|
||||
:param toolgroup_id: ID of the tool group this tool belongs to
|
||||
:param description: Human-readable description of what the tool does
|
||||
:param input_schema: JSON Schema for the tool's input parameters
|
||||
:param output_schema: JSON Schema for the tool's output
|
||||
:param metadata: (Optional) Additional metadata about the tool
|
||||
"""
|
||||
|
||||
type: Literal[ResourceType.tool] = ResourceType.tool
|
||||
toolgroup_id: str
|
||||
description: str
|
||||
input_schema: dict[str, Any] | None = None
|
||||
output_schema: dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolDef(BaseModel):
|
||||
"""Tool definition used in runtime contexts.
|
||||
|
@ -70,8 +28,10 @@ class ToolDef(BaseModel):
|
|||
:param input_schema: (Optional) JSON Schema for tool inputs (MCP inputSchema)
|
||||
:param output_schema: (Optional) JSON Schema for tool outputs (MCP outputSchema)
|
||||
:param metadata: (Optional) Additional metadata about the tool
|
||||
:param toolgroup_id: (Optional) ID of the tool group this tool belongs to
|
||||
"""
|
||||
|
||||
toolgroup_id: str | None = None
|
||||
name: str
|
||||
description: str | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
|
@ -126,7 +86,7 @@ class ToolInvocationResult(BaseModel):
|
|||
|
||||
|
||||
class ToolStore(Protocol):
|
||||
async def get_tool(self, tool_name: str) -> Tool: ...
|
||||
async def get_tool(self, tool_name: str) -> ToolDef: ...
|
||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
|
||||
|
||||
|
||||
|
@ -139,15 +99,6 @@ class ListToolGroupsResponse(BaseModel):
|
|||
data: list[ToolGroup]
|
||||
|
||||
|
||||
class ListToolsResponse(BaseModel):
|
||||
"""Response containing a list of tools.
|
||||
|
||||
:param data: List of tools
|
||||
"""
|
||||
|
||||
data: list[Tool]
|
||||
|
||||
|
||||
class ListToolDefsResponse(BaseModel):
|
||||
"""Response containing a list of tool definitions.
|
||||
|
||||
|
@ -198,11 +149,11 @@ class ToolGroups(Protocol):
|
|||
...
|
||||
|
||||
@webmethod(route="/tools", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
|
||||
"""List tools with optional tool group.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to list tools for.
|
||||
:returns: A ListToolsResponse.
|
||||
:returns: A ListToolDefsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -210,11 +161,11 @@ class ToolGroups(Protocol):
|
|||
async def get_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
) -> Tool:
|
||||
) -> ToolDef:
|
||||
"""Get a tool by its name.
|
||||
|
||||
:param tool_name: The name of the tool to get.
|
||||
:returns: A Tool.
|
||||
:returns: A ToolDef.
|
||||
"""
|
||||
...
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ from llama_stack.apis.safety import Safety
|
|||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
from llama_stack.apis.shields import Shield, ShieldInput
|
||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
|
@ -84,15 +84,11 @@ class BenchmarkWithOwner(Benchmark, ResourceWithOwner):
|
|||
pass
|
||||
|
||||
|
||||
class ToolWithOwner(Tool, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
|
||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | ToolGroup
|
||||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
ModelWithOwner
|
||||
|
@ -101,7 +97,6 @@ RoutableObjectWithProvider = Annotated[
|
|||
| DatasetWithOwner
|
||||
| ScoringFnWithOwner
|
||||
| BenchmarkWithOwner
|
||||
| ToolWithOwner
|
||||
| ToolGroupWithOwner,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
|
@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import (
|
|||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolsResponse,
|
||||
ListToolDefsResponse,
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
|
@ -86,6 +86,6 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolsResponse:
|
||||
) -> ListToolDefsResponse:
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.list_tools(tool_group_id)
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.errors import ToolGroupNotFoundError
|
||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ListToolGroupsResponse, ToolDef, ToolGroup, ToolGroups
|
||||
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -27,7 +27,7 @@ def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name
|
|||
|
||||
|
||||
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||
toolgroups_to_tools: dict[str, list[Tool]] = {}
|
||||
toolgroups_to_tools: dict[str, list[ToolDef]] = {}
|
||||
tool_to_toolgroup: dict[str, str] = {}
|
||||
|
||||
# overridden
|
||||
|
@ -43,7 +43,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
routing_key = self.tool_to_toolgroup[routing_key]
|
||||
return await super().get_provider_impl(routing_key, provider_id)
|
||||
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
|
||||
if toolgroup_id:
|
||||
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
|
||||
toolgroup_id = group_id
|
||||
|
@ -68,31 +68,19 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
continue
|
||||
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
|
||||
|
||||
return ListToolsResponse(data=all_tools)
|
||||
return ListToolDefsResponse(data=all_tools)
|
||||
|
||||
async def _index_tools(self, toolgroup: ToolGroup):
|
||||
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
|
||||
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
|
||||
|
||||
# TODO: kill this Tool vs ToolDef distinction
|
||||
tooldefs = tooldefs_response.data
|
||||
tools = []
|
||||
for t in tooldefs:
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=t.name,
|
||||
toolgroup_id=toolgroup.identifier,
|
||||
description=t.description or "",
|
||||
input_schema=t.input_schema,
|
||||
output_schema=t.output_schema,
|
||||
metadata=t.metadata,
|
||||
provider_id=toolgroup.provider_id,
|
||||
)
|
||||
)
|
||||
t.toolgroup_id = toolgroup.identifier
|
||||
|
||||
self.toolgroups_to_tools[toolgroup.identifier] = tools
|
||||
for tool in tools:
|
||||
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
|
||||
self.toolgroups_to_tools[toolgroup.identifier] = tooldefs
|
||||
for tool in tooldefs:
|
||||
self.tool_to_toolgroup[tool.name] = toolgroup.identifier
|
||||
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
||||
|
@ -103,12 +91,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
raise ToolGroupNotFoundError(toolgroup_id)
|
||||
return tool_group
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
async def get_tool(self, tool_name: str) -> ToolDef:
|
||||
if tool_name in self.tool_to_toolgroup:
|
||||
toolgroup_id = self.tool_to_toolgroup[tool_name]
|
||||
tools = self.toolgroups_to_tools[toolgroup_id]
|
||||
for tool in tools:
|
||||
if tool.identifier == tool_name:
|
||||
if tool.name == tool_name:
|
||||
return tool
|
||||
raise ValueError(f"Tool '{tool_name}' not found")
|
||||
|
||||
|
@ -133,7 +121,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
# baked in some of the code and tests right now.
|
||||
if not toolgroup.mcp_endpoint:
|
||||
await self._index_tools(toolgroup)
|
||||
return toolgroup
|
||||
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
await self.unregister_object(await self.get_tool_group(toolgroup_id))
|
||||
|
|
|
@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
|
|||
|
||||
|
||||
REGISTER_PREFIX = "distributions:registry"
|
||||
KEY_VERSION = "v9"
|
||||
KEY_VERSION = "v10"
|
||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||
|
||||
|
||||
|
|
|
@ -81,7 +81,7 @@ def tool_chat_page():
|
|||
|
||||
for toolgroup_id in toolgroup_selection:
|
||||
tools = client.tools.list(toolgroup_id=toolgroup_id)
|
||||
grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
|
||||
grouped_tools[toolgroup_id] = [tool.name for tool in tools]
|
||||
total_tools += len(tools)
|
||||
|
||||
st.markdown(f"Active Tools: 🛠 {total_tools}")
|
||||
|
|
|
@ -37,14 +37,7 @@ RecursiveType = Primitive | list[Primitive] | dict[str, Primitive]
|
|||
class ToolCall(BaseModel):
|
||||
call_id: str
|
||||
tool_name: BuiltinTool | str
|
||||
# Plan is to deprecate the Dict in favor of a JSON string
|
||||
# that is parsed on the client side instead of trying to manage
|
||||
# the recursive type here.
|
||||
# Making this a union so that client side can start prepping for this change.
|
||||
# Eventually, we will remove both the Dict and arguments_json field,
|
||||
# and arguments will just be a str
|
||||
arguments: str | dict[str, RecursiveType]
|
||||
arguments_json: str | None = None
|
||||
arguments: str
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
|
|
|
@ -232,8 +232,7 @@ class ChatFormat:
|
|||
ToolCall(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
arguments_json=json.dumps(tool_arguments),
|
||||
arguments=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
|
|
@ -298,8 +298,7 @@ class ChatFormat:
|
|||
ToolCall(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
arguments_json=json.dumps(tool_arguments),
|
||||
arguments=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
|
|
@ -804,61 +804,34 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
|
||||
)
|
||||
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
|
||||
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
|
||||
if input_tool_name is not None and not any(tool.name == input_tool_name for tool in tools.data):
|
||||
raise ValueError(
|
||||
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.name for tool in tools.data])}"
|
||||
)
|
||||
|
||||
for tool_def in tools.data:
|
||||
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
||||
identifier: str | BuiltinTool | None = tool_def.identifier
|
||||
identifier: str | BuiltinTool | None = tool_def.name
|
||||
if identifier == "web_search":
|
||||
identifier = BuiltinTool.brave_search
|
||||
else:
|
||||
identifier = BuiltinTool(identifier)
|
||||
else:
|
||||
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
|
||||
if input_tool_name in (None, tool_def.identifier):
|
||||
identifier = tool_def.identifier
|
||||
if input_tool_name in (None, tool_def.name):
|
||||
identifier = tool_def.name
|
||||
else:
|
||||
identifier = None
|
||||
|
||||
if tool_name_to_def.get(identifier, None):
|
||||
raise ValueError(f"Tool {identifier} already exists")
|
||||
if identifier:
|
||||
# Build JSON Schema from tool parameters
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param in tool_def.parameters:
|
||||
param_schema = {
|
||||
"type": param.parameter_type,
|
||||
"description": param.description,
|
||||
}
|
||||
if param.default is not None:
|
||||
param_schema["default"] = param.default
|
||||
if param.items is not None:
|
||||
param_schema["items"] = param.items
|
||||
if param.title is not None:
|
||||
param_schema["title"] = param.title
|
||||
|
||||
properties[param.name] = param_schema
|
||||
|
||||
if param.required:
|
||||
required.append(param.name)
|
||||
|
||||
input_schema = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
|
||||
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
||||
tool_name_to_def[identifier] = ToolDefinition(
|
||||
tool_name=identifier,
|
||||
description=tool_def.description,
|
||||
input_schema=input_schema,
|
||||
input_schema=tool_def.input_schema,
|
||||
)
|
||||
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
|
||||
self.tool_defs, self.tool_name_to_args = (
|
||||
list(tool_name_to_def.values()),
|
||||
|
|
|
@ -33,7 +33,6 @@ from llama_stack.apis.tools import (
|
|||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import (
|
||||
|
@ -301,13 +300,16 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
ToolDef(
|
||||
name="knowledge_search",
|
||||
description="Search for information in a database.",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for. Can be a natural language sentence or keywords.",
|
||||
parameter_type="string",
|
||||
),
|
||||
],
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to search for. Can be a natural language sentence or keywords.",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
|
|
@ -99,8 +99,7 @@ def _convert_to_vllm_tool_calls_in_response(
|
|||
ToolCall(
|
||||
call_id=call.id,
|
||||
tool_name=call.function.name,
|
||||
arguments=json.loads(call.function.arguments),
|
||||
arguments_json=call.function.arguments,
|
||||
arguments=call.function.arguments,
|
||||
)
|
||||
for call in tool_calls
|
||||
]
|
||||
|
@ -160,7 +159,6 @@ def _process_vllm_chat_completion_end_of_stream(
|
|||
for _index, tool_call_buf in sorted(tool_call_bufs.items()):
|
||||
args_str = tool_call_buf.arguments or "{}"
|
||||
try:
|
||||
args = json.loads(args_str)
|
||||
chunks.append(
|
||||
ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
|
@ -169,8 +167,7 @@ def _process_vllm_chat_completion_end_of_stream(
|
|||
tool_call=ToolCall(
|
||||
call_id=tool_call_buf.call_id,
|
||||
tool_name=tool_call_buf.tool_name,
|
||||
arguments=args,
|
||||
arguments_json=args_str,
|
||||
arguments=args_str,
|
||||
),
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
|
|
|
@ -15,7 +15,6 @@ from llama_stack.apis.tools import (
|
|||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
|
@ -57,13 +56,16 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
|
|||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web using Bing Search API",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to search for",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
@ -14,7 +14,6 @@ from llama_stack.apis.tools import (
|
|||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
|
@ -56,13 +55,16 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
|
|||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web for information",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to search for",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
built_in_type=BuiltinTool.brave_search,
|
||||
)
|
||||
]
|
||||
|
|
|
@ -15,7 +15,6 @@ from llama_stack.apis.tools import (
|
|||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
|
@ -56,13 +55,16 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
ToolDef(
|
||||
name="web_search",
|
||||
description="Search the web for information",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to search for",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
@ -15,7 +15,6 @@ from llama_stack.apis.tools import (
|
|||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
|
@ -57,13 +56,16 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
ToolDef(
|
||||
name="wolfram_alpha",
|
||||
description="Query WolframAlpha for computational knowledge",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to compute",
|
||||
parameter_type="string",
|
||||
)
|
||||
],
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to compute",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
@ -538,18 +538,13 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
|
|||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
|
||||
# arguments_json can be None, so attempt it first and fall back to arguments
|
||||
if hasattr(tc, "arguments_json") and tc.arguments_json:
|
||||
arguments = tc.arguments_json
|
||||
else:
|
||||
arguments = json.dumps(tc.arguments)
|
||||
result["tool_calls"].append(
|
||||
{
|
||||
"id": tc.call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments,
|
||||
"arguments": tc.arguments,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
@ -685,8 +680,7 @@ def convert_tool_call(
|
|||
valid_tool_call = ToolCall(
|
||||
call_id=tool_call.id,
|
||||
tool_name=tool_call.function.name,
|
||||
arguments=json.loads(tool_call.function.arguments),
|
||||
arguments_json=tool_call.function.arguments,
|
||||
arguments=tool_call.function.arguments,
|
||||
)
|
||||
except Exception:
|
||||
return UnparseableToolCall(
|
||||
|
@ -897,8 +891,7 @@ def _convert_openai_tool_calls(
|
|||
ToolCall(
|
||||
call_id=call.id,
|
||||
tool_name=call.function.name,
|
||||
arguments=json.loads(call.function.arguments),
|
||||
arguments_json=call.function.arguments,
|
||||
arguments=call.function.arguments,
|
||||
)
|
||||
for call in tool_calls
|
||||
]
|
||||
|
@ -1184,8 +1177,7 @@ async def convert_openai_chat_completion_stream(
|
|||
tool_call = ToolCall(
|
||||
call_id=buffer["call_id"],
|
||||
tool_name=buffer["name"],
|
||||
arguments=arguments,
|
||||
arguments_json=buffer["arguments"],
|
||||
arguments=buffer["arguments"],
|
||||
)
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
|
@ -1418,7 +1410,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
openai_tool_call = OpenAIChoiceDeltaToolCall(
|
||||
index=0,
|
||||
function=OpenAIChoiceDeltaToolCallFunction(
|
||||
arguments=tool_call.arguments_json,
|
||||
arguments=tool_call.arguments,
|
||||
),
|
||||
)
|
||||
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue