mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00

precommit

parent 66ca51ac0d
commit 1a6cb7041d

9 changed files with 43 additions and 26 deletions
@@ -43,7 +43,9 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
             routing_key = self.tool_to_toolgroup[routing_key]
         return await super().get_provider_impl(routing_key, provider_id)
 
-    async def list_tools(self, toolgroup_id: str | None = None, authorization: str | None = None) -> ListToolDefsResponse:
+    async def list_tools(
+        self, toolgroup_id: str | None = None, authorization: str | None = None
+    ) -> ListToolDefsResponse:
         if toolgroup_id:
             if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
                 toolgroup_id = group_id
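For context on the walrus-operator pattern in the surrounding lines, here is a minimal, self-contained sketch of the toolgroup-id normalization. The real parse_toolgroup_from_toolgroup_name_pair in llama-stack may use a different pair syntax, so the "::" separator and the helper body below are assumptions, not the library's code.

def parse_toolgroup_from_toolgroup_name_pair(name: str) -> str | None:
    # Assumption: identifiers may arrive as "<toolgroup_id>::<tool_name>";
    # return the toolgroup part, or None when the value is not such a pair.
    parts = name.split("::")
    return parts[0] if len(parts) == 2 else None


def normalize_toolgroup_id(toolgroup_id: str | None) -> str | None:
    # Same shape as the list_tools body above: only overwrite the id when
    # the pair actually parses.
    if toolgroup_id:
        if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
            toolgroup_id = group_id
    return toolgroup_id


assert normalize_toolgroup_id("my_group::my_tool") == "my_group"
assert normalize_toolgroup_id("my_group") == "my_group"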
@@ -19,6 +19,7 @@ class MCPProviderDataValidator(BaseModel):
 
     This validator is kept for future provider-data extensions if needed.
     """
 
     pass
 
+
@@ -25,9 +25,7 @@ from .config import MCPProviderConfig
 logger = get_logger(__name__, category="tools")
 
 
-class ModelContextProtocolToolRuntimeImpl(
-    ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
-):
+class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
     def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
         self.config = config
 
@@ -52,9 +50,7 @@ class ModelContextProtocolToolRuntimeImpl(
 
         # Use authorization parameter for MCP servers that require auth
         headers = {}
-        return await list_mcp_tools(
-            endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
-        )
+        return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization)
 
     async def invoke_tool(
         self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
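The headers/authorization split in this hunk suggests a simple merge step somewhere downstream. The helper below is a sketch under that assumption, not llama-stack's actual code: it shows one plausible way an authorization value like the one passed to list_mcp_tools could be turned into request headers for an MCP server that expects a bearer token.

def build_mcp_headers(base_headers: dict[str, str] | None, authorization: str | None) -> dict[str, str]:
    # Copy the caller-supplied headers and, if a token was provided, attach it.
    headers = dict(base_headers or {})
    if authorization:
        # Assumption: the server expects a standard Bearer credential.
        headers["Authorization"] = f"Bearer {authorization}"
    return headers


assert build_mcp_headers({}, "secret-token") == {"Authorization": "Bearer secret-token"}
assert build_mcp_headers(None, None) == {}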
@@ -76,9 +72,7 @@ class ModelContextProtocolToolRuntimeImpl(
             authorization=authorization,
         )
 
-    async def get_headers_from_request(
-        self, mcp_endpoint_uri: str
-    ) -> tuple[dict[str, str], str | None]:
+    async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
         """
         Placeholder method for extracting headers and authorization.
 
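get_headers_from_request is documented above as a placeholder. Purely as an illustration of the declared return type, and not the project's implementation, a no-op version could look like this:

async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
    # Hypothetical no-op body: with authorization now passed explicitly to
    # list_tools/invoke_tool, nothing needs to be pulled from provider data.
    return {}, None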
@@ -885,7 +885,9 @@ def patch_inference_clients():
     OllamaAsyncClient.list = patched_ollama_list
 
     # Create patched methods for tool runtimes
-    async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None):
+    async def patched_tavily_invoke_tool(
+        self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
+    ):
         return await _patched_tool_invoke_method(
             _original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs, authorization=authorization
         )
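The tests hunk only reformats the wrapper signature, but it hints at the recording pattern used by patch_inference_clients: stash the original method, install an async wrapper, and delegate. The sketch below reconstructs that pattern with made-up names (patch_async_method, a recorder callback); the real _patched_tool_invoke_method and _original_methods plumbing may differ.

from typing import Any, Callable

_original_methods: dict[str, Any] = {}


def patch_async_method(cls: type, attr: str, key: str, recorder: Callable[..., None]) -> None:
    # Keep a reference to the original so the wrapper can still call it
    # (and so it can be restored when patching is undone).
    _original_methods[key] = getattr(cls, attr)

    async def patched(self, *args: Any, **kwargs: Any):
        recorder(key, args, kwargs)  # e.g. record the call for later replay
        return await _original_methods[key](self, *args, **kwargs)

    setattr(cls, attr, patched)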