precommit

This commit is contained in:
Omar Abdelwahab 2025-11-12 19:02:54 -08:00
parent 66ca51ac0d
commit 1a6cb7041d
9 changed files with 43 additions and 26 deletions

View file

@ -43,7 +43,9 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
routing_key = self.tool_to_toolgroup[routing_key]
return await super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None, authorization: str | None = None) -> ListToolDefsResponse:
async def list_tools(
self, toolgroup_id: str | None = None, authorization: str | None = None
) -> ListToolDefsResponse:
if toolgroup_id:
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
toolgroup_id = group_id

View file

@ -19,6 +19,7 @@ class MCPProviderDataValidator(BaseModel):
This validator is kept for future provider-data extensions if needed.
"""
pass

View file

@ -25,9 +25,7 @@ from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools")
class ModelContextProtocolToolRuntimeImpl(
ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config
@ -52,9 +50,7 @@ class ModelContextProtocolToolRuntimeImpl(
# Use authorization parameter for MCP servers that require auth
headers = {}
return await list_mcp_tools(
endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
)
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization)
async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
@ -76,9 +72,7 @@ class ModelContextProtocolToolRuntimeImpl(
authorization=authorization,
)
async def get_headers_from_request(
self, mcp_endpoint_uri: str
) -> tuple[dict[str, str], str | None]:
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
"""
Placeholder method for extracting headers and authorization.

View file

@ -885,7 +885,9 @@ def patch_inference_clients():
OllamaAsyncClient.list = patched_ollama_list
# Create patched methods for tool runtimes
async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None):
async def patched_tavily_invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
):
return await _patched_tool_invoke_method(
_original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs, authorization=authorization
)