precommit

This commit is contained in:
Omar Abdelwahab 2025-11-07 12:14:42 -08:00
parent 445135b8cc
commit ccb870c8fb
6 changed files with 49 additions and 112 deletions

View file

@ -25,9 +25,7 @@ from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools")
class ModelContextProtocolToolRuntimeImpl(
ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config
@ -47,13 +45,9 @@ class ModelContextProtocolToolRuntimeImpl(
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri)
return await list_mcp_tools(
endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
)
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization)
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult:
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
raise ValueError(f"Tool {tool_name} does not have metadata")
@ -70,9 +64,7 @@ class ModelContextProtocolToolRuntimeImpl(
authorization=authorization,
)
async def get_headers_from_request(
self, mcp_endpoint_uri: str
) -> tuple[dict[str, str], str | None]:
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
"""
Extract headers and authorization from request provider data.