forked from phoenix-oss/llama-stack-mirror
feat: enable MCP execution in Responses impl (#2240)
## Test Plan ``` pytest -s -v 'tests/verifications/openai_api/test_responses.py' \ --provider=stack:together --model meta-llama/Llama-4-Scout-17B-16E-Instruct ```
This commit is contained in:
parent
66f09f24ed
commit
3faf1e4a79
15 changed files with 865 additions and 382 deletions
|
@ -10,6 +10,9 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
|
||||
# take their YAML and generate this file automatically. Their YAML is available.
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseError(BaseModel):
|
||||
|
@ -79,16 +82,45 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
|
||||
arguments: str
|
||||
call_id: str
|
||||
name: str
|
||||
arguments: str
|
||||
type: Literal["function_call"] = "function_call"
|
||||
id: str | None = None
|
||||
status: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageMCPCall(BaseModel):
|
||||
id: str
|
||||
status: str
|
||||
type: Literal["mcp_call"] = "mcp_call"
|
||||
arguments: str
|
||||
name: str
|
||||
server_label: str
|
||||
error: str | None = None
|
||||
output: str | None = None
|
||||
|
||||
|
||||
class MCPListToolsTool(BaseModel):
|
||||
input_schema: dict[str, Any]
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageMCPListTools(BaseModel):
|
||||
id: str
|
||||
type: Literal["mcp_list_tools"] = "mcp_list_tools"
|
||||
server_label: str
|
||||
tools: list[MCPListToolsTool]
|
||||
|
||||
|
||||
OpenAIResponseOutput = Annotated[
|
||||
OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseMessage
|
||||
| OpenAIResponseOutputMessageWebSearchToolCall
|
||||
| OpenAIResponseOutputMessageFunctionToolCall
|
||||
| OpenAIResponseOutputMessageMCPCall
|
||||
| OpenAIResponseOutputMessageMCPListTools,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
||||
|
|
|
@ -14,6 +14,7 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
AllowedToolsFilter,
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIResponseInput,
|
||||
|
@ -22,7 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
|
@ -51,11 +52,12 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
|
||||
|
@ -168,6 +170,15 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
|||
response: OpenAIResponseObject
|
||||
|
||||
|
||||
class ChatCompletionContext(BaseModel):
|
||||
model: str
|
||||
messages: list[OpenAIMessageParam]
|
||||
tools: list[ChatCompletionToolParam] | None = None
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
|
||||
stream: bool
|
||||
temperature: float | None
|
||||
|
||||
|
||||
class OpenAIResponsesImpl:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -255,13 +266,32 @@ class OpenAIResponsesImpl:
|
|||
temperature: float | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
):
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
|
||||
stream = False if stream is None else stream
|
||||
|
||||
# Huge TODO: we need to run this in a loop, until morale improves
|
||||
|
||||
# Create context to run "chat completion"
|
||||
input = await self._prepend_previous_response(input, previous_response_id)
|
||||
messages = await _convert_response_input_to_chat_messages(input)
|
||||
await self._prepend_instructions(messages, instructions)
|
||||
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
|
||||
chat_tools, mcp_tool_to_server, mcp_list_message = (
|
||||
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
|
||||
)
|
||||
if mcp_list_message:
|
||||
output_messages.append(mcp_list_message)
|
||||
|
||||
ctx = ChatCompletionContext(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=chat_tools,
|
||||
mcp_tool_to_server=mcp_tool_to_server,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
# Run inference
|
||||
chat_response = await self.inference_api.openai_chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -270,6 +300,7 @@ class OpenAIResponsesImpl:
|
|||
temperature=temperature,
|
||||
)
|
||||
|
||||
# Collect output
|
||||
if stream:
|
||||
# TODO: refactor this into a separate method that handles streaming
|
||||
chat_response_id = ""
|
||||
|
@ -328,11 +359,11 @@ class OpenAIResponsesImpl:
|
|||
# dump and reload to map to our pydantic types
|
||||
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
|
||||
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
# Execute tool calls if any
|
||||
for choice in chat_response.choices:
|
||||
if choice.message.tool_calls and tools:
|
||||
# Assume if the first tool is a function, all tools are functions
|
||||
if isinstance(tools[0], OpenAIResponseInputToolFunction):
|
||||
if tools[0].type == "function":
|
||||
for tool_call in choice.message.tool_calls:
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
|
@ -344,11 +375,12 @@ class OpenAIResponsesImpl:
|
|||
)
|
||||
)
|
||||
else:
|
||||
output_messages.extend(
|
||||
await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
|
||||
)
|
||||
tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
|
||||
output_messages.extend(tool_messages)
|
||||
else:
|
||||
output_messages.append(await _convert_chat_choice_to_response_message(choice))
|
||||
|
||||
# Create response object
|
||||
response = OpenAIResponseObject(
|
||||
created_at=chat_response.created,
|
||||
id=f"resp-{uuid.uuid4()}",
|
||||
|
@ -359,9 +391,8 @@ class OpenAIResponsesImpl:
|
|||
)
|
||||
logger.debug(f"OpenAI Responses response: {response}")
|
||||
|
||||
# Store response if requested
|
||||
if store:
|
||||
# Store in kvstore
|
||||
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
if isinstance(input, str):
|
||||
# synthesize a message from the input string
|
||||
|
@ -403,7 +434,36 @@ class OpenAIResponsesImpl:
|
|||
|
||||
async def _convert_response_tools_to_chat_tools(
|
||||
self, tools: list[OpenAIResponseInputTool]
|
||||
) -> list[ChatCompletionToolParam]:
|
||||
) -> tuple[
|
||||
list[ChatCompletionToolParam],
|
||||
dict[str, OpenAIResponseInputToolMCP],
|
||||
OpenAIResponseOutput | None,
|
||||
]:
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
)
|
||||
from llama_stack.apis.tools.tools import Tool
|
||||
|
||||
mcp_tool_to_server = {}
|
||||
|
||||
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
},
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(tool_def)
|
||||
|
||||
mcp_list_message = None
|
||||
chat_tools: list[ChatCompletionToolParam] = []
|
||||
for input_tool in tools:
|
||||
# TODO: Handle other tool types
|
||||
|
@ -412,91 +472,95 @@ class OpenAIResponsesImpl:
|
|||
elif input_tool.type == "web_search":
|
||||
tool_name = "web_search"
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
},
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
chat_tools.append(make_openai_tool(tool_name, tool))
|
||||
elif input_tool.type == "mcp":
|
||||
always_allowed = None
|
||||
never_allowed = None
|
||||
if input_tool.allowed_tools:
|
||||
if isinstance(input_tool.allowed_tools, list):
|
||||
always_allowed = input_tool.allowed_tools
|
||||
elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
|
||||
always_allowed = input_tool.allowed_tools.always
|
||||
never_allowed = input_tool.allowed_tools.never
|
||||
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=input_tool.server_url,
|
||||
headers=input_tool.headers or {},
|
||||
)
|
||||
chat_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
chat_tools.append(chat_tool)
|
||||
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
status="completed",
|
||||
server_label=input_tool.server_label,
|
||||
tools=[],
|
||||
)
|
||||
for t in tool_defs.data:
|
||||
if never_allowed and t.name in never_allowed:
|
||||
continue
|
||||
if not always_allowed or t.name in always_allowed:
|
||||
chat_tools.append(make_openai_tool(t.name, t))
|
||||
if t.name in mcp_tool_to_server:
|
||||
raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
|
||||
mcp_tool_to_server[t.name] = input_tool
|
||||
mcp_list_message.tools.append(
|
||||
MCPListToolsTool(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
p.name: {
|
||||
"type": p.parameter_type,
|
||||
"description": p.description,
|
||||
}
|
||||
for p in t.parameters
|
||||
},
|
||||
"required": [p.name for p in t.parameters if p.required],
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
|
||||
return chat_tools
|
||||
return chat_tools, mcp_tool_to_server, mcp_list_message
|
||||
|
||||
async def _execute_tool_and_return_final_output(
|
||||
self,
|
||||
model_id: str,
|
||||
stream: bool,
|
||||
choice: OpenAIChoice,
|
||||
messages: list[OpenAIMessageParam],
|
||||
temperature: float,
|
||||
ctx: ChatCompletionContext,
|
||||
) -> list[OpenAIResponseOutput]:
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
|
||||
# If the choice is not an assistant message, we don't need to execute any tools
|
||||
if not isinstance(choice.message, OpenAIAssistantMessageParam):
|
||||
return output_messages
|
||||
|
||||
# If the assistant message doesn't have any tool calls, we don't need to execute any tools
|
||||
if not choice.message.tool_calls:
|
||||
return output_messages
|
||||
|
||||
# Copy the messages list to avoid mutating the original list
|
||||
messages = messages.copy()
|
||||
next_turn_messages = ctx.messages.copy()
|
||||
|
||||
# Add the assistant message with tool_calls response to the messages list
|
||||
messages.append(choice.message)
|
||||
next_turn_messages.append(choice.message)
|
||||
|
||||
for tool_call in choice.message.tool_calls:
|
||||
tool_call_id = tool_call.id
|
||||
function = tool_call.function
|
||||
|
||||
# If for some reason the tool call doesn't have a function or id, we can't execute it
|
||||
if not function or not tool_call_id:
|
||||
continue
|
||||
|
||||
# TODO: telemetry spans for tool calls
|
||||
result = await self._execute_tool_call(function)
|
||||
|
||||
# Handle tool call failure
|
||||
if not result:
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
status="failed",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
status="completed",
|
||||
),
|
||||
)
|
||||
|
||||
result_content = ""
|
||||
# TODO: handle other result content types and lists
|
||||
if isinstance(result.content, str):
|
||||
result_content = result.content
|
||||
messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
|
||||
tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
|
||||
if tool_call_log:
|
||||
output_messages.append(tool_call_log)
|
||||
if further_input:
|
||||
next_turn_messages.append(further_input)
|
||||
|
||||
tool_results_chat_response = await self.inference_api.openai_chat_completion(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
model=ctx.model,
|
||||
messages=next_turn_messages,
|
||||
stream=ctx.stream,
|
||||
temperature=ctx.temperature,
|
||||
)
|
||||
# type cast to appease mypy
|
||||
# type cast to appease mypy: this is needed because we don't handle streaming properly :)
|
||||
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
|
||||
|
||||
# Huge TODO: these are NOT the final outputs, we must keep the loop going
|
||||
tool_final_outputs = [
|
||||
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
|
||||
]
|
||||
|
@ -506,15 +570,86 @@ class OpenAIResponsesImpl:
|
|||
|
||||
async def _execute_tool_call(
|
||||
self,
|
||||
function: OpenAIChatCompletionToolCallFunction,
|
||||
) -> ToolInvocationResult | None:
|
||||
if not function.name:
|
||||
return None
|
||||
function_args = json.loads(function.arguments) if function.arguments else {}
|
||||
logger.info(f"executing tool call: {function.name} with args: {function_args}")
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function.name,
|
||||
kwargs=function_args,
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
ctx: ChatCompletionContext,
|
||||
) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
logger.debug(f"tool call {function.name} completed with result: {result}")
|
||||
return result
|
||||
|
||||
tool_call_id = tool_call.id
|
||||
function = tool_call.function
|
||||
|
||||
if not function or not tool_call_id or not function.name:
|
||||
return None, None
|
||||
|
||||
error_exc = None
|
||||
result = None
|
||||
try:
|
||||
if function.name in ctx.mcp_tool_to_server:
|
||||
mcp_tool = ctx.mcp_tool_to_server[function.name]
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
tool_name=function.name,
|
||||
kwargs=json.loads(function.arguments) if function.arguments else {},
|
||||
)
|
||||
else:
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function.name,
|
||||
kwargs=json.loads(function.arguments) if function.arguments else {},
|
||||
)
|
||||
except Exception as e:
|
||||
error_exc = e
|
||||
|
||||
if function.name in ctx.mcp_tool_to_server:
|
||||
from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
|
||||
|
||||
message = OpenAIResponseOutputMessageMCPCall(
|
||||
id=tool_call_id,
|
||||
arguments=function.arguments,
|
||||
name=function.name,
|
||||
server_label=ctx.mcp_tool_to_server[function.name].server_label,
|
||||
)
|
||||
if error_exc:
|
||||
message.error = str(error_exc)
|
||||
elif (result.error_code and result.error_code > 0) or result.error_message:
|
||||
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
||||
elif result.content:
|
||||
message.output = interleaved_content_as_str(result.content)
|
||||
else:
|
||||
if function.name == "web_search":
|
||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
status="completed",
|
||||
)
|
||||
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
|
||||
message.status = "failed"
|
||||
else:
|
||||
raise ValueError(f"Unknown tool {function.name} called")
|
||||
|
||||
input_message = None
|
||||
if result and result.content:
|
||||
if isinstance(result.content, str):
|
||||
content = result.content
|
||||
elif isinstance(result.content, list):
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
|
||||
content = []
|
||||
for item in result.content:
|
||||
if isinstance(item, TextContentItem):
|
||||
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
||||
elif isinstance(item, ImageContentItem):
|
||||
if item.image.data:
|
||||
url = f"data:image;base64,{item.image.data}"
|
||||
else:
|
||||
url = item.image.url
|
||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(item)}")
|
||||
content.append(part)
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
||||
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
||||
|
||||
return message, input_message
|
||||
|
|
|
@ -4,53 +4,26 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import exceptiongroup
|
||||
import httpx
|
||||
from mcp import ClientSession
|
||||
from mcp import types as mcp_types
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
|
||||
|
||||
from .config import MCPProviderConfig
|
||||
|
||||
logger = get_logger(__name__, category="tools")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
|
||||
try:
|
||||
async with sse_client(endpoint, headers=headers) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
except BaseException as e:
|
||||
if isinstance(e, exceptiongroup.BaseExceptionGroup):
|
||||
for exc in e.exceptions:
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
|
||||
raise AuthenticationRequiredError(exc) from exc
|
||||
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
|
||||
raise AuthenticationRequiredError(e) from e
|
||||
|
||||
raise
|
||||
|
||||
|
||||
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
|
||||
self.config = config
|
||||
|
@ -64,32 +37,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
|
|||
# this endpoint should be retrieved by getting the tool group right?
|
||||
if mcp_endpoint is None:
|
||||
raise ValueError("mcp_endpoint is required")
|
||||
|
||||
headers = await self.get_headers_from_request(mcp_endpoint.uri)
|
||||
tools = []
|
||||
async with sse_client_wrapper(mcp_endpoint.uri, headers) as session:
|
||||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
parameters = []
|
||||
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=param_name,
|
||||
parameter_type=param_schema.get("type", "string"),
|
||||
description=param_schema.get("description", ""),
|
||||
)
|
||||
)
|
||||
tools.append(
|
||||
ToolDef(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=parameters,
|
||||
metadata={
|
||||
"endpoint": mcp_endpoint.uri,
|
||||
},
|
||||
)
|
||||
)
|
||||
return ListToolDefsResponse(data=tools)
|
||||
return await list_mcp_tools(mcp_endpoint.uri, headers)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
tool = await self.tool_store.get_tool(tool_name)
|
||||
|
@ -100,23 +49,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
|
|||
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
||||
|
||||
headers = await self.get_headers_from_request(endpoint)
|
||||
async with sse_client_wrapper(endpoint, headers) as session:
|
||||
result = await session.call_tool(tool.identifier, kwargs)
|
||||
|
||||
content = []
|
||||
for item in result.content:
|
||||
if isinstance(item, mcp_types.TextContent):
|
||||
content.append(TextContentItem(text=item.text))
|
||||
elif isinstance(item, mcp_types.ImageContent):
|
||||
content.append(ImageContentItem(image=item.data))
|
||||
elif isinstance(item, mcp_types.EmbeddedResource):
|
||||
logger.warning(f"EmbeddedResource is not supported: {item}")
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {type(item)}")
|
||||
return ToolInvocationResult(
|
||||
content=content,
|
||||
error_code=1 if result.isError else 0,
|
||||
)
|
||||
return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
|
||||
|
||||
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
|
||||
def canonicalize_uri(uri: str) -> str:
|
||||
|
@ -129,9 +62,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
|
|||
for uri, values in provider_data.mcp_headers.items():
|
||||
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
|
||||
continue
|
||||
for entry in values:
|
||||
parts = entry.split(":")
|
||||
if len(parts) == 2:
|
||||
k, v = parts
|
||||
headers[k.strip()] = v.strip()
|
||||
headers.update(convert_header_list_to_dict(values))
|
||||
return headers
|
||||
|
|
110
llama_stack/providers/utils/tools/mcp.py
Normal file
110
llama_stack/providers/utils/tools/mcp.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
# for python < 3.11
|
||||
import exceptiongroup
|
||||
|
||||
BaseExceptionGroup = exceptiongroup.BaseExceptionGroup
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import httpx
|
||||
from mcp import ClientSession
|
||||
from mcp import types as mcp_types
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="tools")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
|
||||
try:
|
||||
async with sse_client(endpoint, headers=headers) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
except BaseException as e:
|
||||
if isinstance(e, BaseExceptionGroup):
|
||||
for exc in e.exceptions:
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
|
||||
raise AuthenticationRequiredError(exc) from exc
|
||||
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
|
||||
raise AuthenticationRequiredError(e) from e
|
||||
|
||||
raise
|
||||
|
||||
|
||||
def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]:
|
||||
headers = {}
|
||||
for header in header_list:
|
||||
parts = header.split(":")
|
||||
if len(parts) == 2:
|
||||
k, v = parts
|
||||
headers[k.strip()] = v.strip()
|
||||
return headers
|
||||
|
||||
|
||||
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
|
||||
tools = []
|
||||
async with sse_client_wrapper(endpoint, headers) as session:
|
||||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
parameters = []
|
||||
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=param_name,
|
||||
parameter_type=param_schema.get("type", "string"),
|
||||
description=param_schema.get("description", ""),
|
||||
)
|
||||
)
|
||||
tools.append(
|
||||
ToolDef(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=parameters,
|
||||
metadata={
|
||||
"endpoint": endpoint,
|
||||
},
|
||||
)
|
||||
)
|
||||
return ListToolDefsResponse(data=tools)
|
||||
|
||||
|
||||
async def invoke_mcp_tool(
|
||||
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
async with sse_client_wrapper(endpoint, headers) as session:
|
||||
result = await session.call_tool(tool_name, kwargs)
|
||||
|
||||
content: list[InterleavedContentItem] = []
|
||||
for item in result.content:
|
||||
if isinstance(item, mcp_types.TextContent):
|
||||
content.append(TextContentItem(text=item.text))
|
||||
elif isinstance(item, mcp_types.ImageContent):
|
||||
content.append(ImageContentItem(image=item.data))
|
||||
elif isinstance(item, mcp_types.EmbeddedResource):
|
||||
logger.warning(f"EmbeddedResource is not supported: {item}")
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {type(item)}")
|
||||
return ToolInvocationResult(
|
||||
content=content,
|
||||
error_code=1 if result.isError else 0,
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue