precommit

This commit is contained in:
Omar Abdelwahab 2025-11-12 19:02:54 -08:00
parent 66ca51ac0d
commit 1a6cb7041d
9 changed files with 43 additions and 26 deletions

View file

@ -1881,6 +1881,13 @@ paths:
required: false required: false
schema: schema:
$ref: '#/components/schemas/URL' $ref: '#/components/schemas/URL'
- name: authorization
in: query
description: >-
(Optional) OAuth access token for authenticating with the MCP server.
required: false
schema:
type: string
deprecated: false deprecated: false
/v1/toolgroups: /v1/toolgroups:
get: get:
@ -9086,6 +9093,10 @@ components:
- type: object - type: object
description: >- description: >-
A dictionary of arguments to pass to the tool. A dictionary of arguments to pass to the tool.
authorization:
type: string
description: >-
(Optional) OAuth access token for authenticating with the MCP server.
additionalProperties: false additionalProperties: false
required: required:
- tool_name - tool_name

View file

@ -1878,6 +1878,13 @@ paths:
required: false required: false
schema: schema:
$ref: '#/components/schemas/URL' $ref: '#/components/schemas/URL'
- name: authorization
in: query
description: >-
(Optional) OAuth access token for authenticating with the MCP server.
required: false
schema:
type: string
deprecated: false deprecated: false
/v1/toolgroups: /v1/toolgroups:
get: get:
@ -8370,6 +8377,10 @@ components:
- type: object - type: object
description: >- description: >-
A dictionary of arguments to pass to the tool. A dictionary of arguments to pass to the tool.
authorization:
type: string
description: >-
(Optional) OAuth access token for authenticating with the MCP server.
additionalProperties: false additionalProperties: false
required: required:
- tool_name - tool_name

View file

@ -1881,6 +1881,13 @@ paths:
required: false required: false
schema: schema:
$ref: '#/components/schemas/URL' $ref: '#/components/schemas/URL'
- name: authorization
in: query
description: >-
(Optional) OAuth access token for authenticating with the MCP server.
required: false
schema:
type: string
deprecated: false deprecated: false
/v1/toolgroups: /v1/toolgroups:
get: get:
@ -9086,6 +9093,10 @@ components:
- type: object - type: object
description: >- description: >-
A dictionary of arguments to pass to the tool. A dictionary of arguments to pass to the tool.
authorization:
type: string
description: >-
(Optional) OAuth access token for authenticating with the MCP server.
additionalProperties: false additionalProperties: false
required: required:
- tool_name - tool_name

View file

@ -43,7 +43,9 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
routing_key = self.tool_to_toolgroup[routing_key] routing_key = self.tool_to_toolgroup[routing_key]
return await super().get_provider_impl(routing_key, provider_id) return await super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None, authorization: str | None = None) -> ListToolDefsResponse: async def list_tools(
self, toolgroup_id: str | None = None, authorization: str | None = None
) -> ListToolDefsResponse:
if toolgroup_id: if toolgroup_id:
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id): if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
toolgroup_id = group_id toolgroup_id = group_id

View file

@ -19,6 +19,7 @@ class MCPProviderDataValidator(BaseModel):
This validator is kept for future provider-data extensions if needed. This validator is kept for future provider-data extensions if needed.
""" """
pass pass

View file

@ -25,9 +25,7 @@ from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools") logger = get_logger(__name__, category="tools")
class ModelContextProtocolToolRuntimeImpl( class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]): def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config self.config = config
@ -52,9 +50,7 @@ class ModelContextProtocolToolRuntimeImpl(
# Use authorization parameter for MCP servers that require auth # Use authorization parameter for MCP servers that require auth
headers = {} headers = {}
return await list_mcp_tools( return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization)
endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
)
async def invoke_tool( async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
@ -76,9 +72,7 @@ class ModelContextProtocolToolRuntimeImpl(
authorization=authorization, authorization=authorization,
) )
async def get_headers_from_request( async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
self, mcp_endpoint_uri: str
) -> tuple[dict[str, str], str | None]:
""" """
Placeholder method for extracting headers and authorization. Placeholder method for extracting headers and authorization.

View file

@ -885,7 +885,9 @@ def patch_inference_clients():
OllamaAsyncClient.list = patched_ollama_list OllamaAsyncClient.list = patched_ollama_list
# Create patched methods for tool runtimes # Create patched methods for tool runtimes
async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None): async def patched_tavily_invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
):
return await _patched_tool_invoke_method( return await _patched_tool_invoke_method(
_original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs, authorization=authorization _original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs, authorization=authorization
) )

View file

@ -9,8 +9,6 @@ Integration tests for inference/chat completion with JSON Schema-based tools.
Tests that tools pass through correctly to various LLM providers. Tests that tools pass through correctly to various LLM providers.
""" """
import json
import pytest import pytest
from llama_stack.core.library_client import LlamaStackAsLibraryClient from llama_stack.core.library_client import LlamaStackAsLibraryClient

View file

@ -9,8 +9,6 @@ Integration tests for MCP tools with complex JSON Schema support.
Tests $ref, $defs, and other JSON Schema features through MCP integration. Tests $ref, $defs, and other JSON Schema features through MCP integration.
""" """
import json
import pytest import pytest
from llama_stack.core.library_client import LlamaStackAsLibraryClient from llama_stack.core.library_client import LlamaStackAsLibraryClient
@ -123,8 +121,6 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# List runtime tools # List runtime tools
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
@ -163,7 +159,6 @@ class TestMCPSchemaPreservation:
provider_id="model-context-protocol", provider_id="model-context-protocol",
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# List tools # List tools
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
@ -210,8 +205,6 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN, authorization=AUTH_TOKEN,
@ -254,8 +247,6 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# List tools to populate the tool index # List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools( llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
@ -297,8 +288,6 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# List tools to populate the tool index # List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools( llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
@ -350,8 +339,6 @@ class TestAgentWithMCPTools:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
tools_list = llama_stack_client.tools.list( tools_list = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id, toolgroup_id=test_toolgroup_id,
authorization=AUTH_TOKEN, authorization=AUTH_TOKEN,