mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
Updated get_headers_from_request
This commit is contained in:
parent
dbe41d9510
commit
e8cb52683d
1 changed files with 28 additions and 6 deletions
|
|
@ -44,8 +44,10 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
|
||||||
# this endpoint should be retrieved by getting the tool group right?
|
# this endpoint should be retrieved by getting the tool group right?
|
||||||
if mcp_endpoint is None:
|
if mcp_endpoint is None:
|
||||||
raise ValueError("mcp_endpoint is required")
|
raise ValueError("mcp_endpoint is required")
|
||||||
headers = await self.get_headers_from_request(mcp_endpoint.uri)
|
headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri)
|
||||||
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers)
|
return await list_mcp_tools(
|
||||||
|
endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
|
||||||
|
)
|
||||||
|
|
||||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||||
tool = await self.tool_store.get_tool(tool_name)
|
tool = await self.tool_store.get_tool(tool_name)
|
||||||
|
|
@ -55,24 +57,44 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
|
||||||
if urlparse(endpoint).scheme not in ("http", "https"):
|
if urlparse(endpoint).scheme not in ("http", "https"):
|
||||||
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
||||||
|
|
||||||
headers = await self.get_headers_from_request(endpoint)
|
headers, authorization = await self.get_headers_from_request(endpoint)
|
||||||
return await invoke_mcp_tool(
|
return await invoke_mcp_tool(
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
tool_name=tool_name,
|
tool_name=tool_name,
|
||||||
kwargs=kwargs,
|
kwargs=kwargs,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
authorization=authorization,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
|
async def get_headers_from_request(
|
||||||
|
self, mcp_endpoint_uri: str
|
||||||
|
) -> tuple[dict[str, str], str | None]:
|
||||||
|
"""
|
||||||
|
Extract headers and authorization from request provider data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (headers_dict, authorization_token)
|
||||||
|
- headers_dict: All headers except Authorization
|
||||||
|
- authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None
|
||||||
|
"""
|
||||||
|
|
||||||
def canonicalize_uri(uri: str) -> str:
|
def canonicalize_uri(uri: str) -> str:
|
||||||
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
|
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
|
||||||
|
|
||||||
headers = {}
|
headers = {}
|
||||||
|
authorization = None
|
||||||
|
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
if provider_data and provider_data.mcp_headers:
|
if provider_data and provider_data.mcp_headers:
|
||||||
for uri, values in provider_data.mcp_headers.items():
|
for uri, values in provider_data.mcp_headers.items():
|
||||||
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
|
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
|
||||||
continue
|
continue
|
||||||
headers.update(values)
|
# Extract Authorization header separately for security
|
||||||
return headers
|
for key, value in values.items():
|
||||||
|
if key.lower() == "authorization":
|
||||||
|
# Remove "Bearer " prefix if present
|
||||||
|
authorization = value.removeprefix("Bearer ").strip()
|
||||||
|
else:
|
||||||
|
headers[key] = value
|
||||||
|
|
||||||
|
return headers, authorization
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue