mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-21 09:52:25 +00:00
final changes
This commit is contained in:
parent
a297d27d48
commit
de065a60f2
11 changed files with 142 additions and 108 deletions
|
|
@ -40,7 +40,7 @@ class Api(Enum):
|
|||
datasets = "datasets"
|
||||
scoring_functions = "scoring_functions"
|
||||
eval_tasks = "eval_tasks"
|
||||
tools = "tools"
|
||||
tool_groups = "tool_groups"
|
||||
|
||||
# built-in API
|
||||
inspect = "inspect"
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import Any, Dict, List
|
|||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
|
|
@ -42,11 +42,11 @@ class BraveSearchToolRuntimeImpl(
|
|||
)
|
||||
return provider_data.api_key
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
|
||||
raise NotImplementedError("Brave search tool group not supported")
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_id: str, args: Dict[str, Any]
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
url = "https://api.search.brave.com/res/v1/web/search"
|
||||
|
|
|
|||
|
|
@ -8,14 +8,15 @@ from typing import Any, Dict, List
|
|||
from urllib.parse import urlparse
|
||||
|
||||
from llama_stack.apis.tools import (
|
||||
MCPToolGroup,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
MCPToolGroupDef,
|
||||
ToolDef,
|
||||
ToolGroupDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
|
|
@ -29,8 +30,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
|
||||
if not isinstance(tool_group, MCPToolGroup):
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
|
||||
if not isinstance(tool_group, MCPToolGroupDef):
|
||||
raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
|
||||
|
||||
tools = []
|
||||
|
|
@ -51,25 +52,23 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
)
|
||||
)
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=tool.name,
|
||||
ToolDef(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
tool_group=tool_group.name,
|
||||
parameters=parameters,
|
||||
metadata={
|
||||
"endpoint": tool_group.endpoint.uri,
|
||||
},
|
||||
provider_resource_id=tool.name,
|
||||
)
|
||||
)
|
||||
return tools
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_id: str, args: Dict[str, Any]
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_id)
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
if tool.metadata is None or tool.metadata.get("endpoint") is None:
|
||||
raise ValueError(f"Tool {tool_id} does not have metadata")
|
||||
raise ValueError(f"Tool {tool_name} does not have metadata")
|
||||
endpoint = tool.metadata.get("endpoint")
|
||||
if urlparse(endpoint).scheme not in ("http", "https"):
|
||||
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue