From dcb3dc42116668688322fb62c732c232e9490c83 Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Wed, 5 Nov 2025 11:41:02 -0800 Subject: [PATCH] raising an error when the authentication field is present in the authorization field and in the header --- .../meta_reference/responses/streaming.py | 12 ++++--- .../meta_reference/responses/tool_executor.py | 17 ++++++---- .../responses/test_mcp_authentication.py | 34 ++++++++----------- 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 152d67617..029ba7b89 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -1082,11 +1082,15 @@ class StreamingResponseOrchestrator: # Prepare headers with authorization from tool config headers = dict(mcp_tool.headers or {}) if mcp_tool.authorization: - # Don't override existing Authorization header (case-insensitive check) + # Check if Authorization header already exists (case-insensitive check) existing_keys_lower = {k.lower() for k in headers.keys()} - if "authorization" not in existing_keys_lower: - # OAuth access token - add "Bearer " prefix - headers["Authorization"] = f"Bearer {mcp_tool.authorization}" + if "authorization" in existing_keys_lower: + raise ValueError( + "Cannot specify Authorization in both 'headers' and 'authorization' fields. " + "Please use only the 'authorization' field." + ) + # OAuth access token - add "Bearer " prefix + headers["Authorization"] = f"Bearer {mcp_tool.authorization}" async with tracing.span("list_mcp_tools", attributes): tool_defs = await list_mcp_tools( diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py index a2490d17b..d0dc1557a 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py @@ -302,11 +302,15 @@ class ToolExecutor: # Prepare headers with authorization from tool config headers = dict(mcp_tool.headers or {}) if mcp_tool.authorization: - # Don't override existing Authorization header (case-insensitive check) + # Check if Authorization header already exists (case-insensitive check) existing_keys_lower = {k.lower() for k in headers.keys()} - if "authorization" not in existing_keys_lower: - # OAuth access token - add "Bearer " prefix - headers["Authorization"] = f"Bearer {mcp_tool.authorization}" + if "authorization" in existing_keys_lower: + raise ValueError( + "Cannot specify Authorization in both 'headers' and 'authorization' fields. " + "Please use only the 'authorization' field." + ) + # OAuth access token - add "Bearer " prefix + headers["Authorization"] = f"Bearer {mcp_tool.authorization}" async with tracing.span("invoke_mcp_tool", attributes): result = await invoke_mcp_tool( @@ -369,7 +373,6 @@ class ToolExecutor: mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( sequence_number=sequence_number, ) - yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number) elif function_name == "web_search": sequence_number += 1 web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( @@ -485,6 +488,8 @@ class ToolExecutor: input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type] else: text = str(error_exc) if error_exc else "Tool execution failed" - input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id) + input_message = OpenAIToolMessageParam( + content=text, tool_call_id=tool_call_id + ) return message, input_message diff --git a/tests/integration/responses/test_mcp_authentication.py b/tests/integration/responses/test_mcp_authentication.py index 5473684bb..bfcf578ac 100644 --- a/tests/integration/responses/test_mcp_authentication.py +++ b/tests/integration/responses/test_mcp_authentication.py @@ -91,40 +91,36 @@ def test_mcp_authorization_different_token(compat_client, text_model_id): assert response.output[1].error is None -def test_mcp_authorization_fallback_to_headers(compat_client, text_model_id): - """Test that authorization parameter doesn't override existing Authorization header.""" +def test_mcp_authorization_error_when_both_provided(compat_client, text_model_id): + """Test that providing both headers['Authorization'] and authorization field raises an error.""" if not isinstance(compat_client, LlamaStackAsLibraryClient): pytest.skip("in-process MCP server is only supported in library client") - # Headers should take precedence - this test uses headers auth - test_token = "headers-token-123" + test_token = "test-token-123" with make_mcp_server(required_auth_token=test_token) as mcp_server_info: tools = setup_mcp_tools( [ { "type": "mcp", - "server_label": "headers-mcp", + "server_label": "both-auth-mcp", "server_url": "", "headers": {"Authorization": f"Bearer {test_token}"}, - "authorization": "should-not-override", # Just the token + "authorization": "should-cause-error", # This should trigger an error } ], mcp_server_info, ) - # Create response - headers should take precedence - response = compat_client.responses.create( - model=text_model_id, - input="What is the boiling point of myawesomeliquid?", - tools=tools, - stream=False, - ) - - # Verify operations succeeded with headers auth - assert len(response.output) >= 3 - assert response.output[0].type == "mcp_list_tools" - assert response.output[1].type == "mcp_call" - assert response.output[1].error is None + # Create response - should raise ValueError + with pytest.raises( + ValueError, match="Cannot specify Authorization in both 'headers' and 'authorization' fields" + ): + compat_client.responses.create( + model=text_model_id, + input="What is the boiling point of myawesomeliquid?", + tools=tools, + stream=False, + ) def test_mcp_authorization_backward_compatibility(compat_client, text_model_id):