final changes

This commit is contained in:
Dinesh Yeduguru 2024-12-19 20:59:47 -08:00
parent a297d27d48
commit de065a60f2
11 changed files with 142 additions and 108 deletions

View file

@ -48,30 +48,34 @@ class ToolDef(BaseModel):
@json_schema_type
class MCPToolGroup(BaseModel):
class MCPToolGroupDef(BaseModel):
"""
A tool group that is defined by in a model context protocol server.
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
"""
type: Literal["model_context_protocol"] = "model_context_protocol"
name: str
endpoint: URL
@json_schema_type
class UserDefinedToolGroup(BaseModel):
class UserDefinedToolGroupDef(BaseModel):
type: Literal["user_defined"] = "user_defined"
name: str
tools: List[ToolDef]
ToolGroup = register_schema(
Annotated[Union[MCPToolGroup, UserDefinedToolGroup], Field(discriminator="type")],
ToolGroupDef = register_schema(
Annotated[
Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
],
name="ToolGroup",
)
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
@json_schema_type
class ToolInvocationResult(BaseModel):
content: InterleavedContent
@ -80,34 +84,44 @@ class ToolInvocationResult(BaseModel):
class ToolStore(Protocol):
def get_tool(self, tool_id: str) -> Tool: ...
def get_tool(self, tool_name: str) -> Tool: ...
@runtime_checkable
@trace_protocol
class Tools(Protocol):
class ToolGroups(Protocol):
@webmethod(route="/toolgroups/register", method="POST")
async def register_tool_group(
self,
tool_group: ToolGroup,
tool_group_id: str,
tool_group: ToolGroupDef,
provider_id: Optional[str] = None,
) -> None:
"""Register a tool group"""
...
@webmethod(route="/tools/get", method="GET")
async def get_tool(
@webmethod(route="/toolgroups/get", method="GET")
async def get_tool_group(
self,
tool_id: str,
) -> Tool: ...
tool_group_id: str,
) -> ToolGroup: ...
@webmethod(route="/toolgroups/list", method="GET")
async def list_tool_groups(self) -> List[ToolGroup]:
"""List tool groups with optional provider"""
...
@webmethod(route="/tools/list", method="GET")
async def list_tools(self) -> List[Tool]:
"""List tools with optional provider"""
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
"""List tools with optional tool group"""
...
@webmethod(route="/tools/unregister", method="POST")
async def unregister_tool(self, tool_id: str) -> None:
"""Unregister a tool"""
@webmethod(route="/tools/get", method="GET")
async def get_tool(self, tool_name: str) -> Tool: ...
@webmethod(route="/toolgroups/unregister", method="POST")
async def unregister_tool_group(self, tool_group_id: str) -> None:
"""Unregister a tool group"""
...
@ -117,11 +131,11 @@ class ToolRuntime(Protocol):
tool_store: ToolStore
@webmethod(route="/tool-runtime/discover", method="POST")
async def discover_tools(self, tool_group: ToolGroup) -> List[Tool]: ...
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(
self, tool_id: str, args: Dict[str, Any]
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
"""Run a tool with the given arguments"""
...