mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
fix: MCP authorization parameter implementation (#4052)
# What does this PR do? Adding a user-facing `authorization ` parameter to MCP tool definitions that allows users to explicitly configure credentials per MCP server, addressing GitHub Issue #4034 in a secure manner. ## Test Plan tests/integration/responses/test_mcp_authentication.py --------- Co-authored-by: Omar Abdelwahab <omara@fb.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
dc49ad3f89
commit
eb545034ab
34 changed files with 5205 additions and 62 deletions
|
|
@ -34,16 +34,16 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
logger.debug("ToolRuntimeRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None) -> Any:
|
||||
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
provider = await self.routing_table.get_provider_impl(tool_name)
|
||||
return await provider.invoke_tool(
|
||||
tool_name=tool_name,
|
||||
kwargs=kwargs,
|
||||
authorization=authorization,
|
||||
)
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, authorization: str | None = None
|
||||
) -> ListToolDefsResponse:
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.list_tools(tool_group_id)
|
||||
return await self.routing_table.list_tools(tool_group_id, authorization=authorization)
|
||||
|
|
|
|||
|
|
@ -49,7 +49,9 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
routing_key = self.tool_to_toolgroup[routing_key]
|
||||
return await super().get_provider_impl(routing_key, provider_id)
|
||||
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
|
||||
async def list_tools(
|
||||
self, toolgroup_id: str | None = None, authorization: str | None = None
|
||||
) -> ListToolDefsResponse:
|
||||
if toolgroup_id:
|
||||
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
|
||||
toolgroup_id = group_id
|
||||
|
|
@ -61,7 +63,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
for toolgroup in toolgroups:
|
||||
if toolgroup.identifier not in self.toolgroups_to_tools:
|
||||
try:
|
||||
await self._index_tools(toolgroup)
|
||||
await self._index_tools(toolgroup, authorization=authorization)
|
||||
except AuthenticationRequiredError:
|
||||
# Send authentication errors back to the client so it knows
|
||||
# that it needs to supply credentials for remote MCP servers.
|
||||
|
|
@ -76,9 +78,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
|
||||
return ListToolDefsResponse(data=all_tools)
|
||||
|
||||
async def _index_tools(self, toolgroup: ToolGroup):
|
||||
async def _index_tools(self, toolgroup: ToolGroup, authorization: str | None = None):
|
||||
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
|
||||
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
|
||||
tooldefs_response = await provider_impl.list_runtime_tools(
|
||||
toolgroup.identifier, toolgroup.mcp_endpoint, authorization=authorization
|
||||
)
|
||||
|
||||
tooldefs = tooldefs_response.data
|
||||
for t in tooldefs:
|
||||
|
|
|
|||
|
|
@ -257,6 +257,19 @@ class OpenAIResponsesImpl:
|
|||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
||||
# Validate MCP tools: ensure Authorization header is not passed via headers dict
|
||||
if tools:
|
||||
from llama_stack_api.openai_responses import OpenAIResponseInputToolMCP
|
||||
|
||||
for tool in tools:
|
||||
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.headers:
|
||||
for key in tool.headers.keys():
|
||||
if key.lower() == "authorization":
|
||||
raise ValueError(
|
||||
"Authorization header cannot be passed via 'headers'. "
|
||||
"Please use the 'authorization' parameter instead."
|
||||
)
|
||||
|
||||
guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []
|
||||
|
||||
if conversation is not None:
|
||||
|
|
|
|||
|
|
@ -1091,10 +1091,12 @@ class StreamingResponseOrchestrator:
|
|||
"server_url": mcp_tool.server_url,
|
||||
"mcp_list_tools_id": list_id,
|
||||
}
|
||||
# List MCP tools with authorization from tool config
|
||||
async with tracing.span("list_mcp_tools", attributes):
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
headers=mcp_tool.headers,
|
||||
authorization=mcp_tool.authorization,
|
||||
)
|
||||
|
||||
# Create the MCP list tools message
|
||||
|
|
|
|||
|
|
@ -296,12 +296,14 @@ class ToolExecutor:
|
|||
"server_url": mcp_tool.server_url,
|
||||
"tool_name": function_name,
|
||||
}
|
||||
# Invoke MCP tool with authorization from tool config
|
||||
async with tracing.span("invoke_mcp_tool", attributes):
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
headers=mcp_tool.headers,
|
||||
authorization=mcp_tool.authorization,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
response_file_search_tool = (
|
||||
|
|
|
|||
|
|
@ -276,7 +276,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
|
|||
)
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
# Parameters are not listed since these methods are not yet invoked automatically
|
||||
# by the LLM. The method is only implemented so things like /tools can list without
|
||||
|
|
@ -304,7 +307,9 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
|
|||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
|
||||
) -> ToolInvocationResult:
|
||||
vector_store_ids = kwargs.get("vector_store_ids", [])
|
||||
query_config = kwargs.get("query_config")
|
||||
if query_config:
|
||||
|
|
|
|||
|
|
@ -49,7 +49,10 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
|
|||
return provider_data.bing_search_api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
|
|
@ -70,7 +73,9 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
|
|||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": api_key,
|
||||
|
|
|
|||
|
|
@ -48,7 +48,10 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
|
|||
return provider_data.brave_search_api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
|
|
@ -70,7 +73,9 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
|
|||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
url = "https://api.search.brave.com/res/v1/web/search"
|
||||
headers = {
|
||||
|
|
|
|||
|
|
@ -10,8 +10,14 @@ from pydantic import BaseModel
|
|||
|
||||
|
||||
class MCPProviderDataValidator(BaseModel):
|
||||
# mcp_endpoint => dict of headers to send
|
||||
mcp_headers: dict[str, dict[str, str]] | None = None
|
||||
"""
|
||||
Validator for MCP provider-specific data passed via request headers.
|
||||
|
||||
Phase 1: Support old header-based authentication for backward compatibility.
|
||||
In Phase 2, this will be deprecated in favor of the authorization parameter.
|
||||
"""
|
||||
|
||||
mcp_headers: dict[str, dict[str, str]] | None = None # Map of URI -> headers dict
|
||||
|
||||
|
||||
class MCPProviderConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -39,15 +39,29 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
|
|||
return
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
# this endpoint should be retrieved by getting the tool group right?
|
||||
if mcp_endpoint is None:
|
||||
raise ValueError("mcp_endpoint is required")
|
||||
headers = await self.get_headers_from_request(mcp_endpoint.uri)
|
||||
return await list_mcp_tools(mcp_endpoint.uri, headers)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
# Phase 1: Support both old header-based auth AND new authorization parameter
|
||||
# Get headers and auth from provider data (old approach)
|
||||
provider_headers, provider_auth = await self.get_headers_from_request(mcp_endpoint.uri)
|
||||
|
||||
# New authorization parameter takes precedence over provider data
|
||||
final_authorization = authorization or provider_auth
|
||||
|
||||
return await list_mcp_tools(
|
||||
endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=final_authorization
|
||||
)
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
|
||||
) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
if tool.metadata is None or tool.metadata.get("endpoint") is None:
|
||||
raise ValueError(f"Tool {tool_name} does not have metadata")
|
||||
|
|
@ -55,19 +69,57 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
|
|||
if urlparse(endpoint).scheme not in ("http", "https"):
|
||||
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
||||
|
||||
headers = await self.get_headers_from_request(endpoint)
|
||||
return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
|
||||
# Phase 1: Support both old header-based auth AND new authorization parameter
|
||||
# Get headers and auth from provider data (old approach)
|
||||
provider_headers, provider_auth = await self.get_headers_from_request(endpoint)
|
||||
|
||||
# New authorization parameter takes precedence over provider data
|
||||
final_authorization = authorization or provider_auth
|
||||
|
||||
return await invoke_mcp_tool(
|
||||
endpoint=endpoint,
|
||||
tool_name=tool_name,
|
||||
kwargs=kwargs,
|
||||
headers=provider_headers,
|
||||
authorization=final_authorization,
|
||||
)
|
||||
|
||||
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
|
||||
"""
|
||||
Extract headers and authorization from request provider data (Phase 1 backward compatibility).
|
||||
|
||||
Phase 1: Temporarily allows Authorization to be passed via mcp_headers for backward compatibility.
|
||||
Phase 2: Will enforce that Authorization should use the dedicated authorization parameter instead.
|
||||
|
||||
Returns:
|
||||
Tuple of (headers_dict, authorization_token)
|
||||
- headers_dict: All headers except Authorization
|
||||
- authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None
|
||||
"""
|
||||
|
||||
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
|
||||
def canonicalize_uri(uri: str) -> str:
|
||||
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
|
||||
|
||||
headers = {}
|
||||
authorization = None
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data and provider_data.mcp_headers:
|
||||
if provider_data and hasattr(provider_data, "mcp_headers") and provider_data.mcp_headers:
|
||||
for uri, values in provider_data.mcp_headers.items():
|
||||
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
|
||||
continue
|
||||
headers.update(values)
|
||||
return headers
|
||||
|
||||
# Phase 1: Extract Authorization from mcp_headers for backward compatibility
|
||||
# (Phase 2 will reject this and require the dedicated authorization parameter)
|
||||
for key in values.keys():
|
||||
if key.lower() == "authorization":
|
||||
# Extract authorization token and strip "Bearer " prefix if present
|
||||
auth_value = values[key]
|
||||
if auth_value.startswith("Bearer "):
|
||||
authorization = auth_value[7:] # Remove "Bearer " prefix
|
||||
else:
|
||||
authorization = auth_value
|
||||
else:
|
||||
headers[key] = values[key]
|
||||
|
||||
return headers, authorization
|
||||
|
|
|
|||
|
|
@ -48,7 +48,10 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
return provider_data.tavily_search_api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
|
|
@ -69,7 +72,9 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
|
|
|
|||
|
|
@ -49,7 +49,10 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
return provider_data.wolfram_alpha_api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
|
|
@ -70,7 +73,9 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
params = {
|
||||
"input": kwargs["query"],
|
||||
|
|
|
|||
|
|
@ -30,6 +30,40 @@ from llama_stack_api import (
|
|||
|
||||
logger = get_logger(__name__, category="tools")
|
||||
|
||||
|
||||
def prepare_mcp_headers(base_headers: dict[str, str] | None, authorization: str | None) -> dict[str, str]:
|
||||
"""
|
||||
Prepare headers for MCP requests with authorization support.
|
||||
|
||||
Args:
|
||||
base_headers: Base headers dictionary (can be None)
|
||||
authorization: OAuth access token (without "Bearer " prefix)
|
||||
|
||||
Returns:
|
||||
Headers dictionary with Authorization header if token provided
|
||||
|
||||
Raises:
|
||||
ValueError: If Authorization header is specified in the headers dict (security risk)
|
||||
"""
|
||||
headers = dict(base_headers or {})
|
||||
|
||||
# Security check: reject any Authorization header in the headers dict
|
||||
# Users must use the authorization parameter instead to avoid security risks
|
||||
existing_keys_lower = {k.lower() for k in headers.keys()}
|
||||
if "authorization" in existing_keys_lower:
|
||||
raise ValueError(
|
||||
"For security reasons, Authorization header cannot be passed via 'headers'. "
|
||||
"Please use the 'authorization' parameter instead."
|
||||
)
|
||||
|
||||
# Add Authorization header if token provided
|
||||
if authorization:
|
||||
# OAuth access token - add "Bearer " prefix
|
||||
headers["Authorization"] = f"Bearer {authorization}"
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
protocol_cache = TTLDict(ttl_seconds=3600)
|
||||
|
||||
|
||||
|
|
@ -112,9 +146,29 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
|
|||
raise
|
||||
|
||||
|
||||
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
|
||||
async def list_mcp_tools(
|
||||
endpoint: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
"""List tools available from an MCP server.
|
||||
|
||||
Args:
|
||||
endpoint: MCP server endpoint URL
|
||||
headers: Optional base headers to include
|
||||
authorization: Optional OAuth access token (just the token, not "Bearer <token>")
|
||||
|
||||
Returns:
|
||||
List of tool definitions from the MCP server
|
||||
|
||||
Raises:
|
||||
ValueError: If Authorization is found in the headers parameter
|
||||
"""
|
||||
# Prepare headers with authorization handling
|
||||
final_headers = prepare_mcp_headers(headers, authorization)
|
||||
|
||||
tools = []
|
||||
async with client_wrapper(endpoint, headers) as session:
|
||||
async with client_wrapper(endpoint, final_headers) as session:
|
||||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
tools.append(
|
||||
|
|
@ -132,9 +186,31 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
|
|||
|
||||
|
||||
async def invoke_mcp_tool(
|
||||
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
|
||||
endpoint: str,
|
||||
tool_name: str,
|
||||
kwargs: dict[str, Any],
|
||||
headers: dict[str, str] | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ToolInvocationResult:
|
||||
async with client_wrapper(endpoint, headers) as session:
|
||||
"""Invoke an MCP tool with the given arguments.
|
||||
|
||||
Args:
|
||||
endpoint: MCP server endpoint URL
|
||||
tool_name: Name of the tool to invoke
|
||||
kwargs: Tool invocation arguments
|
||||
headers: Optional base headers to include
|
||||
authorization: Optional OAuth access token (just the token, not "Bearer <token>")
|
||||
|
||||
Returns:
|
||||
Tool invocation result with content and error information
|
||||
|
||||
Raises:
|
||||
ValueError: If Authorization header is found in the headers parameter
|
||||
"""
|
||||
# Prepare headers with authorization handling
|
||||
final_headers = prepare_mcp_headers(headers, authorization)
|
||||
|
||||
async with client_wrapper(endpoint, final_headers) as session:
|
||||
result = await session.call_tool(tool_name, kwargs)
|
||||
|
||||
content: list[InterleavedContentItem] = []
|
||||
|
|
|
|||
|
|
@ -609,14 +609,14 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
|
|||
|
||||
|
||||
async def _patched_tool_invoke_method(
|
||||
original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any]
|
||||
original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
|
||||
):
|
||||
"""Patched version of tool runtime invoke_tool method for recording/replay."""
|
||||
global _current_mode, _current_storage
|
||||
|
||||
if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
|
||||
# Normal operation
|
||||
return await original_method(self, tool_name, kwargs)
|
||||
return await original_method(self, tool_name, kwargs, authorization=authorization)
|
||||
|
||||
request_hash = normalize_tool_request(provider_name, tool_name, kwargs)
|
||||
|
||||
|
|
@ -634,7 +634,7 @@ async def _patched_tool_invoke_method(
|
|||
|
||||
if _current_mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
|
||||
# Make the tool call and record it
|
||||
result = await original_method(self, tool_name, kwargs)
|
||||
result = await original_method(self, tool_name, kwargs, authorization=authorization)
|
||||
|
||||
request_data = {
|
||||
"test_id": get_test_context(),
|
||||
|
|
@ -885,9 +885,11 @@ def patch_inference_clients():
|
|||
OllamaAsyncClient.list = patched_ollama_list
|
||||
|
||||
# Create patched methods for tool runtimes
|
||||
async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any]):
|
||||
async def patched_tavily_invoke_tool(
|
||||
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
|
||||
):
|
||||
return await _patched_tool_invoke_method(
|
||||
_original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs
|
||||
_original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs, authorization=authorization
|
||||
)
|
||||
|
||||
# Apply tool runtime patches
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue