mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
address feedback
This commit is contained in:
parent
e08b7f4432
commit
a7a55748ca
8 changed files with 21 additions and 24 deletions
|
@ -7,7 +7,7 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat
|
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Protocol, runtime_checkable
|
from typing_extensions import Protocol, runtime_checkable
|
||||||
|
@ -40,7 +40,6 @@ class Tool(Resource):
|
||||||
tool_host: ToolHost
|
tool_host: ToolHost
|
||||||
description: str
|
description: str
|
||||||
parameters: List[ToolParameter]
|
parameters: List[ToolParameter]
|
||||||
built_in_type: Optional[BuiltinTool] = None
|
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||||
default=ToolPromptFormat.json
|
default=ToolPromptFormat.json
|
||||||
|
@ -53,7 +52,6 @@ class ToolDef(BaseModel):
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
parameters: Optional[List[ToolParameter]] = None
|
parameters: Optional[List[ToolParameter]] = None
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
built_in_type: Optional[BuiltinTool] = None
|
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||||
default=ToolPromptFormat.json
|
default=ToolPromptFormat.json
|
||||||
)
|
)
|
||||||
|
@ -130,6 +128,7 @@ class ToolGroups(Protocol):
|
||||||
class ToolRuntime(Protocol):
|
class ToolRuntime(Protocol):
|
||||||
tool_store: ToolStore
|
tool_store: ToolStore
|
||||||
|
|
||||||
|
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
|
||||||
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||||
|
|
|
@ -527,7 +527,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
provider_resource_id=tool_def.name,
|
provider_resource_id=tool_def.name,
|
||||||
metadata=tool_def.metadata,
|
metadata=tool_def.metadata,
|
||||||
tool_host=tool_host,
|
tool_host=tool_host,
|
||||||
built_in_type=tool_def.built_in_type,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
|
|
|
@ -78,6 +78,7 @@ def make_random_string(length: int = 8):
|
||||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||||
MEMORY_QUERY_TOOL = "query_memory"
|
MEMORY_QUERY_TOOL = "query_memory"
|
||||||
WEB_SEARCH_TOOL = "web_search"
|
WEB_SEARCH_TOOL = "web_search"
|
||||||
|
MEMORY_GROUP = "builtin::memory"
|
||||||
|
|
||||||
|
|
||||||
class ChatAgent(ShieldRunnerMixin):
|
class ChatAgent(ShieldRunnerMixin):
|
||||||
|
@ -741,16 +742,24 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
continue
|
continue
|
||||||
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
|
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
|
||||||
for tool_def in tools:
|
for tool_def in tools:
|
||||||
if tool_def.built_in_type:
|
if (
|
||||||
if tool_def_map.get(tool_def.built_in_type, None):
|
toolgroup_name.startswith("builtin")
|
||||||
raise ValueError(
|
and toolgroup_name != MEMORY_GROUP
|
||||||
f"Tool {tool_def.built_in_type} already exists"
|
):
|
||||||
)
|
tool_name = tool_def.identifier
|
||||||
|
built_in_type = BuiltinTool.brave_search
|
||||||
|
if tool_name == "web_search":
|
||||||
|
built_in_type = BuiltinTool.brave_search
|
||||||
|
else:
|
||||||
|
built_in_type = BuiltinTool(tool_name)
|
||||||
|
|
||||||
tool_def_map[tool_def.built_in_type] = ToolDefinition(
|
if tool_def_map.get(built_in_type, None):
|
||||||
tool_name=tool_def.built_in_type
|
raise ValueError(f"Tool {built_in_type} already exists")
|
||||||
|
|
||||||
|
tool_def_map[built_in_type] = ToolDefinition(
|
||||||
|
tool_name=built_in_type
|
||||||
)
|
)
|
||||||
tool_to_group[tool_def.built_in_type] = tool_def.toolgroup_id
|
tool_to_group[built_in_type] = tool_def.toolgroup_id
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if tool_def_map.get(tool_def.identifier, None):
|
if tool_def_map.get(tool_def.identifier, None):
|
||||||
|
|
|
@ -198,7 +198,7 @@ class MockToolGroupsAPI:
|
||||||
toolgroup_id=MEMORY_TOOLGROUP,
|
toolgroup_id=MEMORY_TOOLGROUP,
|
||||||
tool_host=ToolHost.client,
|
tool_host=ToolHost.client,
|
||||||
description="Mock tool",
|
description="Mock tool",
|
||||||
provider_id="mock_provider",
|
provider_id="builtin::memory",
|
||||||
parameters=[],
|
parameters=[],
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
@ -208,10 +208,9 @@ class MockToolGroupsAPI:
|
||||||
identifier="code_interpreter",
|
identifier="code_interpreter",
|
||||||
provider_resource_id="code_interpreter",
|
provider_resource_id="code_interpreter",
|
||||||
toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
|
toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
|
||||||
built_in_type=BuiltinTool.code_interpreter,
|
|
||||||
tool_host=ToolHost.client,
|
tool_host=ToolHost.client,
|
||||||
description="Mock tool",
|
description="Mock tool",
|
||||||
provider_id="mock_provider",
|
provider_id="builtin::code_interpreter",
|
||||||
parameters=[],
|
parameters=[],
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
|
@ -9,8 +9,6 @@ import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
Tool,
|
Tool,
|
||||||
|
@ -58,7 +56,6 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
parameter_type="string",
|
parameter_type="string",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
built_in_type=BuiltinTool.code_interpreter,
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,6 @@ import json
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
@ -65,7 +64,6 @@ class BingSearchToolRuntimeImpl(
|
||||||
parameter_type="string",
|
parameter_type="string",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
built_in_type=BuiltinTool.brave_search,
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,6 @@ import json
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
@ -64,7 +63,6 @@ class TavilySearchToolRuntimeImpl(
|
||||||
parameter_type="string",
|
parameter_type="string",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
built_in_type=BuiltinTool.brave_search,
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,6 @@ import json
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
|
@ -65,7 +64,6 @@ class WolframAlphaToolRuntimeImpl(
|
||||||
parameter_type="string",
|
parameter_type="string",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
built_in_type=BuiltinTool.wolfram_alpha,
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue