Updated get_headers_from_request

This commit is contained in:
Omar Abdelwahab 2025-11-06 11:41:33 -08:00
parent dbe41d9510
commit e8cb52683d

View file

@ -44,8 +44,10 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
# this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
headers = await self.get_headers_from_request(mcp_endpoint.uri)
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers)
headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri)
return await list_mcp_tools(
endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
@ -55,24 +57,44 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
headers = await self.get_headers_from_request(endpoint)
headers, authorization = await self.get_headers_from_request(endpoint)
return await invoke_mcp_tool(
endpoint=endpoint,
tool_name=tool_name,
kwargs=kwargs,
headers=headers,
authorization=authorization,
)
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
async def get_headers_from_request(
self, mcp_endpoint_uri: str
) -> tuple[dict[str, str], str | None]:
"""
Extract headers and authorization from request provider data.
Returns:
Tuple of (headers_dict, authorization_token)
- headers_dict: All headers except Authorization
- authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None
"""
def canonicalize_uri(uri: str) -> str:
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
headers = {}
authorization = None
provider_data = self.get_request_provider_data()
if provider_data and provider_data.mcp_headers:
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
headers.update(values)
return headers
# Extract Authorization header separately for security
for key, value in values.items():
if key.lower() == "authorization":
# Remove "Bearer " prefix if present
authorization = value.removeprefix("Bearer ").strip()
else:
headers[key] = value
return headers, authorization