final changes

This commit is contained in:
Dinesh Yeduguru 2024-12-19 20:59:47 -08:00
parent a297d27d48
commit de065a60f2
11 changed files with 142 additions and 108 deletions

View file

@ -40,7 +40,7 @@ class Api(Enum):
datasets = "datasets"
scoring_functions = "scoring_functions"
eval_tasks = "eval_tasks"
tools = "tools"
tool_groups = "tool_groups"
# built-in API
inspect = "inspect"

View file

@ -8,7 +8,7 @@ from typing import Any, Dict, List
import requests
from llama_stack.apis.tools import Tool, ToolGroup, ToolInvocationResult, ToolRuntime
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -42,11 +42,11 @@ class BraveSearchToolRuntimeImpl(
)
return provider_data.api_key
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
raise NotImplementedError("Brave search tool group not supported")
async def invoke_tool(
self, tool_id: str, args: Dict[str, Any]
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
url = "https://api.search.brave.com/res/v1/web/search"

View file

@ -8,14 +8,15 @@ from typing import Any, Dict, List
from urllib.parse import urlparse
from llama_stack.apis.tools import (
MCPToolGroup,
Tool,
ToolGroup,
MCPToolGroupDef,
ToolDef,
ToolGroupDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from mcp import ClientSession
from mcp.client.sse import sse_client
@ -29,8 +30,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def initialize(self):
pass
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]:
if not isinstance(tool_group, MCPToolGroup):
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
if not isinstance(tool_group, MCPToolGroupDef):
raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
tools = []
@ -51,25 +52,23 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
)
)
tools.append(
Tool(
identifier=tool.name,
ToolDef(
name=tool.name,
description=tool.description,
tool_group=tool_group.name,
parameters=parameters,
metadata={
"endpoint": tool_group.endpoint.uri,
},
provider_resource_id=tool.name,
)
)
return tools
async def invoke_tool(
self, tool_id: str, args: Dict[str, Any]
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_id)
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
raise ValueError(f"Tool {tool_id} does not have metadata")
raise ValueError(f"Tool {tool_name} does not have metadata")
endpoint = tool.metadata.get("endpoint")
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")