Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-01 16:24:44 +00:00)

feat: enable MCP execution in Response implementation

This commit is contained in:
parent a411029d7e
commit 5937d94da5

9 changed files with 728 additions and 174 deletions
docs/_static/llama-stack-spec.html (vendored, 128 changed lines)

@@ -7218,15 +7218,15 @@
     "OpenAIResponseOutputMessageFunctionToolCall": {
       "type": "object",
       "properties": {
-        "arguments": {
-          "type": "string"
-        },
         "call_id": {
           "type": "string"
         },
         "name": {
           "type": "string"
         },
+        "arguments": {
+          "type": "string"
+        },
         "type": {
           "type": "string",
           "const": "function_call",
@@ -7241,12 +7241,10 @@
       },
       "additionalProperties": false,
       "required": [
-        "arguments",
         "call_id",
         "name",
-        "type",
-        "id",
-        "status"
+        "arguments",
+        "type"
       ],
       "title": "OpenAIResponseOutputMessageFunctionToolCall"
     },
@@ -7412,6 +7410,12 @@
         },
         {
           "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
+        },
+        {
+          "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPCall"
+        },
+        {
+          "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
         }
       ],
       "discriminator": {
@@ -7419,10 +7423,118 @@
         "mapping": {
           "message": "#/components/schemas/OpenAIResponseMessage",
           "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall",
-          "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
+          "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
+          "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
+          "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
         }
       }
     },
+    "OpenAIResponseOutputMessageMCPCall": {
+      "type": "object",
+      "properties": {
+        "id": {"type": "string"},
+        "type": {"type": "string", "const": "mcp_call", "default": "mcp_call"},
+        "arguments": {"type": "string"},
+        "name": {"type": "string"},
+        "server_label": {"type": "string"},
+        "error": {"type": "string"},
+        "output": {"type": "string"}
+      },
+      "additionalProperties": false,
+      "required": ["id", "type", "arguments", "name", "server_label"],
+      "title": "OpenAIResponseOutputMessageMCPCall"
+    },
+    "OpenAIResponseOutputMessageMCPListTools": {
+      "type": "object",
+      "properties": {
+        "id": {"type": "string"},
+        "type": {"type": "string", "const": "mcp_list_tools", "default": "mcp_list_tools"},
+        "server_label": {"type": "string"},
+        "tools": {
+          "type": "array",
+          "items": {
+            "type": "object",
+            "properties": {
+              "input_schema": {
+                "type": "object",
+                "additionalProperties": {
+                  "oneOf": [
+                    {"type": "null"},
+                    {"type": "boolean"},
+                    {"type": "number"},
+                    {"type": "string"},
+                    {"type": "array"},
+                    {"type": "object"}
+                  ]
+                }
+              },
+              "name": {"type": "string"},
+              "description": {"type": "string"}
+            },
+            "additionalProperties": false,
+            "required": ["input_schema", "name"],
+            "title": "MCPListToolsTool"
+          }
+        }
+      },
+      "additionalProperties": false,
+      "required": ["id", "type", "server_label", "tools"],
+      "title": "OpenAIResponseOutputMessageMCPListTools"
+    },
     "OpenAIResponseObjectStream": {
       "oneOf": [
         {
docs/_static/llama-stack-spec.yaml (vendored, 81 changed lines)

@@ -5075,12 +5075,12 @@ components:
   "OpenAIResponseOutputMessageFunctionToolCall":
     type: object
     properties:
-      arguments:
-        type: string
       call_id:
         type: string
       name:
         type: string
+      arguments:
+        type: string
       type:
         type: string
         const: function_call
@@ -5091,12 +5091,10 @@ components:
         type: string
     additionalProperties: false
     required:
-      - arguments
       - call_id
       - name
+      - arguments
       - type
-      - id
-      - status
     title: >-
       OpenAIResponseOutputMessageFunctionToolCall
   "OpenAIResponseOutputMessageWebSearchToolCall":
@@ -5214,12 +5212,85 @@ components:
       - $ref: '#/components/schemas/OpenAIResponseMessage'
      - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
      - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
+      - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
+      - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
     discriminator:
       propertyName: type
       mapping:
         message: '#/components/schemas/OpenAIResponseMessage'
         web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
         function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
+        mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
+        mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
+  OpenAIResponseOutputMessageMCPCall:
+    type: object
+    properties:
+      id:
+        type: string
+      type:
+        type: string
+        const: mcp_call
+        default: mcp_call
+      arguments:
+        type: string
+      name:
+        type: string
+      server_label:
+        type: string
+      error:
+        type: string
+      output:
+        type: string
+    additionalProperties: false
+    required:
+      - id
+      - type
+      - arguments
+      - name
+      - server_label
+    title: OpenAIResponseOutputMessageMCPCall
+  OpenAIResponseOutputMessageMCPListTools:
+    type: object
+    properties:
+      id:
+        type: string
+      type:
+        type: string
+        const: mcp_list_tools
+        default: mcp_list_tools
+      server_label:
+        type: string
+      tools:
+        type: array
+        items:
+          type: object
+          properties:
+            input_schema:
+              type: object
+              additionalProperties:
+                oneOf:
+                  - type: 'null'
+                  - type: boolean
+                  - type: number
+                  - type: string
+                  - type: array
+                  - type: object
+            name:
+              type: string
+            description:
+              type: string
+          additionalProperties: false
+          required:
+            - input_schema
+            - name
+          title: MCPListToolsTool
+    additionalProperties: false
+    required:
+      - id
+      - type
+      - server_label
+      - tools
+    title: OpenAIResponseOutputMessageMCPListTools
   OpenAIResponseObjectStream:
     oneOf:
       - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
llama_stack/apis/agents/openai_responses.py

@@ -10,6 +10,9 @@ from pydantic import BaseModel, Field
 from llama_stack.schema_utils import json_schema_type, register_schema
 
+# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
+# take their YAML and generate this file automatically. Their YAML is available.
+
 
 @json_schema_type
 class OpenAIResponseError(BaseModel):
@@ -79,16 +82,45 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
 
 @json_schema_type
 class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
-    arguments: str
     call_id: str
     name: str
+    arguments: str
     type: Literal["function_call"] = "function_call"
-    id: str
-    status: str
+    id: str | None = None
+    status: str | None = None
+
+
+@json_schema_type
+class OpenAIResponseOutputMessageMCPCall(BaseModel):
+    id: str
+    type: Literal["mcp_call"] = "mcp_call"
+    arguments: str
+    name: str
+    server_label: str
+    error: str | None = None
+    output: str | None = None
+
+
+class MCPListToolsTool(BaseModel):
+    input_schema: dict[str, Any]
+    name: str
+    description: str | None = None
+
+
+@json_schema_type
+class OpenAIResponseOutputMessageMCPListTools(BaseModel):
+    id: str
+    type: Literal["mcp_list_tools"] = "mcp_list_tools"
+    server_label: str
+    tools: list[MCPListToolsTool]
+
+
 OpenAIResponseOutput = Annotated[
-    OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
+    OpenAIResponseMessage
+    | OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFunctionToolCall
+    | OpenAIResponseOutputMessageMCPCall
+    | OpenAIResponseOutputMessageMCPListTools,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
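For orientation, a minimal sketch (not part of the commit) of how the two new output item types compose; the identifiers and field values below are purely illustrative:

# Illustrative only: ids and values are made up, not taken from the diff.
from llama_stack.apis.agents.openai_responses import (
    MCPListToolsTool,
    OpenAIResponseOutputMessageMCPCall,
    OpenAIResponseOutputMessageMCPListTools,
)

list_tools_item = OpenAIResponseOutputMessageMCPListTools(
    id="mcp_list_example",  # hypothetical id
    server_label="localmcp",
    tools=[
        MCPListToolsTool(
            name="get_boiling_point",
            description="Returns the boiling point of a liquid.",
            input_schema={
                "type": "object",
                "properties": {"liquid_name": {"type": "string"}},
                "required": ["liquid_name"],
            },
        )
    ],
)

call_item = OpenAIResponseOutputMessageMCPCall(
    id="mcp_call_example",  # hypothetical id
    arguments='{"liquid_name": "polyjuice"}',
    name="get_boiling_point",
    server_label="localmcp",
    output="-100",  # on failure, `error` would be set instead of `output`
)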
OpenAIResponsesImpl (Responses implementation)

@@ -14,6 +14,7 @@ from pydantic import BaseModel
 from llama_stack.apis.agents import Order
 from llama_stack.apis.agents.openai_responses import (
+    AllowedToolsFilter,
     ListOpenAIResponseInputItem,
     ListOpenAIResponseObject,
     OpenAIResponseInput,
@@ -22,7 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
-    OpenAIResponseInputToolFunction,
+    OpenAIResponseInputToolMCP,
     OpenAIResponseMessage,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
@@ -51,11 +52,12 @@ from llama_stack.apis.inference.inference import (
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
 )
-from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
 from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore
+from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
 
 logger = get_logger(name=__name__, category="openai_responses")
 
@@ -168,6 +170,15 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
     response: OpenAIResponseObject
 
 
+class ChatCompletionContext(BaseModel):
+    model: str
+    messages: list[OpenAIMessageParam]
+    tools: list[ChatCompletionToolParam] | None = None
+    mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
+    stream: bool
+    temperature: float | None
+
+
 class OpenAIResponsesImpl:
     def __init__(
         self,
@@ -255,13 +266,32 @@ class OpenAIResponsesImpl:
         temperature: float | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
     ):
+        output_messages: list[OpenAIResponseOutput] = []
+
         stream = False if stream is None else stream
 
+        # Huge TODO: we need to run this in a loop, until morale improves
+
+        # Create context to run "chat completion"
         input = await self._prepend_previous_response(input, previous_response_id)
         messages = await _convert_response_input_to_chat_messages(input)
         await self._prepend_instructions(messages, instructions)
-        chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
+        chat_tools, mcp_tool_to_server, mcp_list_message = (
+            await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
+        )
+        if mcp_list_message:
+            output_messages.append(mcp_list_message)
+
+        ctx = ChatCompletionContext(
+            model=model,
+            messages=messages,
+            tools=chat_tools,
+            mcp_tool_to_server=mcp_tool_to_server,
+            stream=stream,
+            temperature=temperature,
+        )
+
+        # Run inference
         chat_response = await self.inference_api.openai_chat_completion(
             model=model,
             messages=messages,
@@ -270,6 +300,7 @@ class OpenAIResponsesImpl:
             temperature=temperature,
         )
 
+        # Collect output
         if stream:
             # TODO: refactor this into a separate method that handles streaming
             chat_response_id = ""
@@ -328,11 +359,11 @@ class OpenAIResponsesImpl:
             # dump and reload to map to our pydantic types
             chat_response = OpenAIChatCompletion(**chat_response.model_dump())
 
-        output_messages: list[OpenAIResponseOutput] = []
+        # Execute tool calls if any
         for choice in chat_response.choices:
             if choice.message.tool_calls and tools:
                 # Assume if the first tool is a function, all tools are functions
-                if isinstance(tools[0], OpenAIResponseInputToolFunction):
+                if tools[0].type == "function":
                     for tool_call in choice.message.tool_calls:
                         output_messages.append(
                             OpenAIResponseOutputMessageFunctionToolCall(
@@ -344,11 +375,12 @@ class OpenAIResponsesImpl:
                             )
                         )
                 else:
-                    output_messages.extend(
-                        await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
-                    )
+                    tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
+                    output_messages.extend(tool_messages)
             else:
                 output_messages.append(await _convert_chat_choice_to_response_message(choice))
 
+        # Create response object
         response = OpenAIResponseObject(
             created_at=chat_response.created,
             id=f"resp-{uuid.uuid4()}",
@@ -359,9 +391,8 @@ class OpenAIResponsesImpl:
         )
         logger.debug(f"OpenAI Responses response: {response}")
 
+        # Store response if requested
         if store:
-            # Store in kvstore
-
             new_input_id = f"msg_{uuid.uuid4()}"
             if isinstance(input, str):
                 # synthesize a message from the input string
@@ -403,7 +434,36 @@ class OpenAIResponsesImpl:
 
     async def _convert_response_tools_to_chat_tools(
         self, tools: list[OpenAIResponseInputTool]
-    ) -> list[ChatCompletionToolParam]:
+    ) -> tuple[
+        list[ChatCompletionToolParam],
+        dict[str, OpenAIResponseInputToolMCP],
+        OpenAIResponseOutput | None,
+    ]:
+        from llama_stack.apis.agents.openai_responses import (
+            MCPListToolsTool,
+            OpenAIResponseOutputMessageMCPListTools,
+        )
+        from llama_stack.apis.tools.tools import Tool
+
+        mcp_tool_to_server = {}
+
+        def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
+            tool_def = ToolDefinition(
+                tool_name=tool_name,
+                description=tool.description,
+                parameters={
+                    param.name: ToolParamDefinition(
+                        param_type=param.parameter_type,
+                        description=param.description,
+                        required=param.required,
+                        default=param.default,
+                    )
+                    for param in tool.parameters
+                },
+            )
+            return convert_tooldef_to_openai_tool(tool_def)
+
+        mcp_list_message = None
         chat_tools: list[ChatCompletionToolParam] = []
         for input_tool in tools:
             # TODO: Handle other tool types
@@ -412,91 +472,95 @@ class OpenAIResponsesImpl:
             elif input_tool.type == "web_search":
                 tool_name = "web_search"
                 tool = await self.tool_groups_api.get_tool(tool_name)
-                tool_def = ToolDefinition(
-                    tool_name=tool_name,
-                    description=tool.description,
-                    parameters={
-                        param.name: ToolParamDefinition(
-                            param_type=param.parameter_type,
-                            description=param.description,
-                            required=param.required,
-                            default=param.default,
-                        )
-                        for param in tool.parameters
-                    },
-                )
-                chat_tool = convert_tooldef_to_openai_tool(tool_def)
-                chat_tools.append(chat_tool)
+                if not tool:
+                    raise ValueError(f"Tool {tool_name} not found")
+                chat_tools.append(make_openai_tool(tool_name, tool))
+            elif input_tool.type == "mcp":
+                always_allowed = None
+                never_allowed = None
+                if input_tool.allowed_tools:
+                    if isinstance(input_tool.allowed_tools, list):
+                        always_allowed = input_tool.allowed_tools
+                    elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
+                        always_allowed = input_tool.allowed_tools.always
+                        never_allowed = input_tool.allowed_tools.never
+
+                tool_defs = await list_mcp_tools(
+                    endpoint=input_tool.server_url,
+                    headers=convert_header_list_to_dict(input_tool.headers or []),
+                )
+
+                mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
+                    id=f"mcp_list_{uuid.uuid4()}",
+                    status="completed",
+                    server_label=input_tool.server_label,
+                    tools=[],
+                )
+                for t in tool_defs.data:
+                    if never_allowed and t.name in never_allowed:
+                        continue
+                    if not always_allowed or t.name in always_allowed:
+                        chat_tools.append(make_openai_tool(t.name, t))
+                        if t.name in mcp_tool_to_server:
+                            raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
+                        mcp_tool_to_server[t.name] = input_tool
+                        mcp_list_message.tools.append(
+                            MCPListToolsTool(
+                                name=t.name,
+                                description=t.description,
+                                input_schema={
+                                    "type": "object",
+                                    "properties": {
+                                        p.name: {
+                                            "type": p.parameter_type,
+                                            "description": p.description,
+                                        }
+                                        for p in t.parameters
+                                    },
+                                    "required": [p.name for p in t.parameters if p.required],
+                                },
+                            )
+                        )
             else:
                 raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
-        return chat_tools
+        return chat_tools, mcp_tool_to_server, mcp_list_message
 
     async def _execute_tool_and_return_final_output(
         self,
-        model_id: str,
-        stream: bool,
         choice: OpenAIChoice,
-        messages: list[OpenAIMessageParam],
-        temperature: float,
+        ctx: ChatCompletionContext,
     ) -> list[OpenAIResponseOutput]:
         output_messages: list[OpenAIResponseOutput] = []
 
-        # If the choice is not an assistant message, we don't need to execute any tools
         if not isinstance(choice.message, OpenAIAssistantMessageParam):
             return output_messages
 
-        # If the assistant message doesn't have any tool calls, we don't need to execute any tools
         if not choice.message.tool_calls:
             return output_messages
 
-        # Copy the messages list to avoid mutating the original list
-        messages = messages.copy()
+        next_turn_messages = ctx.messages.copy()
 
         # Add the assistant message with tool_calls response to the messages list
-        messages.append(choice.message)
+        next_turn_messages.append(choice.message)
 
         for tool_call in choice.message.tool_calls:
-            tool_call_id = tool_call.id
-            function = tool_call.function
-
-            # If for some reason the tool call doesn't have a function or id, we can't execute it
-            if not function or not tool_call_id:
-                continue
-
             # TODO: telemetry spans for tool calls
-            result = await self._execute_tool_call(function)
-
-            # Handle tool call failure
-            if not result:
-                output_messages.append(
-                    OpenAIResponseOutputMessageWebSearchToolCall(
-                        id=tool_call_id,
-                        status="failed",
-                    )
-                )
-                continue
-
-            output_messages.append(
-                OpenAIResponseOutputMessageWebSearchToolCall(
-                    id=tool_call_id,
-                    status="completed",
-                ),
-            )
-
-            result_content = ""
-            # TODO: handle other result content types and lists
-            if isinstance(result.content, str):
-                result_content = result.content
-            messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
+            tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
+            if tool_call_log:
+                output_messages.append(tool_call_log)
+            if further_input:
+                next_turn_messages.append(further_input)
 
         tool_results_chat_response = await self.inference_api.openai_chat_completion(
-            model=model_id,
-            messages=messages,
-            stream=stream,
-            temperature=temperature,
+            model=ctx.model,
+            messages=next_turn_messages,
+            stream=ctx.stream,
+            temperature=ctx.temperature,
         )
-        # type cast to appease mypy
+        # type cast to appease mypy: this is needed because we don't handle streaming properly :)
         tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
+
+        # Huge TODO: these are NOT the final outputs, we must keep the loop going
         tool_final_outputs = [
             await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
         ]
@@ -506,15 +570,85 @@ class OpenAIResponsesImpl:
 
     async def _execute_tool_call(
         self,
-        function: OpenAIChatCompletionToolCallFunction,
-    ) -> ToolInvocationResult | None:
-        if not function.name:
-            return None
-        function_args = json.loads(function.arguments) if function.arguments else {}
-        logger.info(f"executing tool call: {function.name} with args: {function_args}")
-        result = await self.tool_runtime_api.invoke_tool(
-            tool_name=function.name,
-            kwargs=function_args,
-        )
-        logger.debug(f"tool call {function.name} completed with result: {result}")
-        return result
+        tool_call: OpenAIChatCompletionToolCall,
+        ctx: ChatCompletionContext,
+    ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
+        from llama_stack.providers.utils.inference.prompt_adapter import (
+            interleaved_content_as_str,
+        )
+
+        tool_call_id = tool_call.id
+        function = tool_call.function
+
+        if not function or not tool_call_id or not function.name:
+            return None, None
+
+        error_exc = None
+        try:
+            if function.name in ctx.mcp_tool_to_server:
+                mcp_tool = ctx.mcp_tool_to_server[function.name]
+                result = await invoke_mcp_tool(
+                    endpoint=mcp_tool.server_url,
+                    headers=convert_header_list_to_dict(mcp_tool.headers or []),
+                    tool_name=function.name,
+                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                )
+            else:
+                result = await self.tool_runtime_api.invoke_tool(
+                    tool_name=function.name,
+                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                )
+        except Exception as e:
+            error_exc = e
+
+        if function.name in ctx.mcp_tool_to_server:
+            from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
+
+            message = OpenAIResponseOutputMessageMCPCall(
+                id=tool_call_id,
+                arguments=function.arguments,
+                name=function.name,
+                server_label=ctx.mcp_tool_to_server[function.name].server_label,
+            )
+            if error_exc:
+                message.error = str(error_exc)
+            elif (result.error_code and result.error_code > 0) or result.error_message:
+                message.error = f"Error (code {result.error_code}): {result.error_message}"
+            elif result.content:
+                message.output = interleaved_content_as_str(result.content)
+        else:
+            if function.name == "web_search":
+                message = OpenAIResponseOutputMessageWebSearchToolCall(
+                    id=tool_call_id,
+                    status="completed",
+                )
+                if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+                    message.status = "failed"
+            else:
+                raise ValueError(f"Unknown tool {function.name} called")
+
+        input_message = None
+        if result.content:
+            if isinstance(result.content, str):
+                content = result.content
+            elif isinstance(result.content, list):
+                from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+
+                content = []
+                for item in result.content:
+                    if isinstance(item, TextContentItem):
+                        part = OpenAIChatCompletionContentPartTextParam(text=item.text)
+                    elif isinstance(item, ImageContentItem):
+                        if item.image.data:
+                            url = f"data:image;base64,{item.image.data}"
+                        else:
+                            url = item.image.url
+                        part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
+                    else:
+                        raise ValueError(f"Unknown result content type: {type(item)}")
+                    content.append(part)
+            else:
+                raise ValueError(f"Unknown result content type: {type(result.content)}")
+            input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
+
+        return message, input_message
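A hypothetical request-side sketch (not from the commit) of the allowed_tools filter that _convert_response_tools_to_chat_tools interprets: it accepts either a plain list or an AllowedToolsFilter-style object with always/never lists. Tool names, the server label, and the endpoint below are illustrative assumptions:

# Hypothetical "mcp" tool definitions as plain dicts, mirroring the two branches
# handled in _convert_response_tools_to_chat_tools (values illustrative).
mcp_tool_allow_list = {
    "type": "mcp",
    "server_label": "localmcp",
    "server_url": "<your MCP SSE endpoint>",
    # A plain list is treated as the "always allowed" set.
    "allowed_tools": ["get_boiling_point"],
}

mcp_tool_allow_and_deny = {
    "type": "mcp",
    "server_label": "localmcp",
    "server_url": "<your MCP SSE endpoint>",
    # A filter object can allow some tools and explicitly exclude others.
    "allowed_tools": {"always": ["get_boiling_point"], "never": ["greet_everyone"]},
}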
ModelContextProtocolToolRuntimeImpl (MCP tool runtime provider)

@@ -4,53 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from contextlib import asynccontextmanager
 from typing import Any
 from urllib.parse import urlparse
 
-import exceptiongroup
-import httpx
-from mcp import ClientSession
-from mcp import types as mcp_types
-from mcp.client.sse import sse_client
-
-from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.datatypes import Api
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
-    ToolDef,
     ToolInvocationResult,
-    ToolParameter,
     ToolRuntime,
 )
-from llama_stack.distribution.datatypes import AuthenticationRequiredError
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
 
 from .config import MCPProviderConfig
 
 logger = get_logger(__name__, category="tools")
 
 
-@asynccontextmanager
-async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
-    try:
-        async with sse_client(endpoint, headers=headers) as streams:
-            async with ClientSession(*streams) as session:
-                await session.initialize()
-                yield session
-    except BaseException as e:
-        if isinstance(e, exceptiongroup.BaseExceptionGroup):
-            for exc in e.exceptions:
-                if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
-                    raise AuthenticationRequiredError(exc) from exc
-        elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
-            raise AuthenticationRequiredError(e) from e
-
-        raise
-
-
 class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
     def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
         self.config = config
@@ -64,32 +37,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
         # this endpoint should be retrieved by getting the tool group right?
         if mcp_endpoint is None:
             raise ValueError("mcp_endpoint is required")
 
         headers = await self.get_headers_from_request(mcp_endpoint.uri)
-        tools = []
-        async with sse_client_wrapper(mcp_endpoint.uri, headers) as session:
-            tools_result = await session.list_tools()
-            for tool in tools_result.tools:
-                parameters = []
-                for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
-                    parameters.append(
-                        ToolParameter(
-                            name=param_name,
-                            parameter_type=param_schema.get("type", "string"),
-                            description=param_schema.get("description", ""),
-                        )
-                    )
-                tools.append(
-                    ToolDef(
-                        name=tool.name,
-                        description=tool.description,
-                        parameters=parameters,
-                        metadata={
-                            "endpoint": mcp_endpoint.uri,
-                        },
-                    )
-                )
-        return ListToolDefsResponse(data=tools)
+        return await list_mcp_tools(mcp_endpoint.uri, headers)
 
     async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
@@ -100,23 +49,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
             raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
 
         headers = await self.get_headers_from_request(endpoint)
-        async with sse_client_wrapper(endpoint, headers) as session:
-            result = await session.call_tool(tool.identifier, kwargs)
-
-            content = []
-            for item in result.content:
-                if isinstance(item, mcp_types.TextContent):
-                    content.append(TextContentItem(text=item.text))
-                elif isinstance(item, mcp_types.ImageContent):
-                    content.append(ImageContentItem(image=item.data))
-                elif isinstance(item, mcp_types.EmbeddedResource):
-                    logger.warning(f"EmbeddedResource is not supported: {item}")
-                else:
-                    raise ValueError(f"Unknown content type: {type(item)}")
-            return ToolInvocationResult(
-                content=content,
-                error_code=1 if result.isError else 0,
-            )
+        return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
 
     async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
         def canonicalize_uri(uri: str) -> str:
@@ -129,9 +62,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
             for uri, values in provider_data.mcp_headers.items():
                 if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
                     continue
-                for entry in values:
-                    parts = entry.split(":")
-                    if len(parts) == 2:
-                        k, v = parts
-                        headers[k.strip()] = v.strip()
+                headers.update(convert_header_list_to_dict(values))
         return headers
llama_stack/providers/utils/tools/mcp.py (new file, 103 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from contextlib import asynccontextmanager
from typing import Any

import exceptiongroup
import httpx
from mcp import ClientSession
from mcp import types as mcp_types
from mcp.client.sse import sse_client

from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
from llama_stack.apis.tools import (
    ListToolDefsResponse,
    ToolDef,
    ToolInvocationResult,
    ToolParameter,
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger

logger = get_logger(__name__, category="tools")


@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
    try:
        async with sse_client(endpoint, headers=headers) as streams:
            async with ClientSession(*streams) as session:
                await session.initialize()
                yield session
    except BaseException as e:
        if isinstance(e, exceptiongroup.BaseExceptionGroup):
            for exc in e.exceptions:
                if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
                    raise AuthenticationRequiredError(exc) from exc
        elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
            raise AuthenticationRequiredError(e) from e

        raise


def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]:
    headers = {}
    for header in header_list:
        parts = header.split(":")
        if len(parts) == 2:
            k, v = parts
            headers[k.strip()] = v.strip()
    return headers


async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
    tools = []
    async with sse_client_wrapper(endpoint, headers) as session:
        tools_result = await session.list_tools()
        for tool in tools_result.tools:
            parameters = []
            for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
                parameters.append(
                    ToolParameter(
                        name=param_name,
                        parameter_type=param_schema.get("type", "string"),
                        description=param_schema.get("description", ""),
                    )
                )
            tools.append(
                ToolDef(
                    name=tool.name,
                    description=tool.description,
                    parameters=parameters,
                    metadata={
                        "endpoint": endpoint,
                    },
                )
            )
    return ListToolDefsResponse(data=tools)


async def invoke_mcp_tool(
    endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult:
    async with sse_client_wrapper(endpoint, headers) as session:
        result = await session.call_tool(tool_name, kwargs)

        content: list[InterleavedContentItem] = []
        for item in result.content:
            if isinstance(item, mcp_types.TextContent):
                content.append(TextContentItem(text=item.text))
            elif isinstance(item, mcp_types.ImageContent):
                content.append(ImageContentItem(image=item.data))
            elif isinstance(item, mcp_types.EmbeddedResource):
                logger.warning(f"EmbeddedResource is not supported: {item}")
            else:
                raise ValueError(f"Unknown content type: {type(item)}")
        return ToolInvocationResult(
            content=content,
            error_code=1 if result.isError else 0,
        )
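A minimal usage sketch of these new helpers; the endpoint URL and header value below are assumptions, not taken from the commit:

import asyncio

from llama_stack.providers.utils.tools.mcp import (
    convert_header_list_to_dict,
    invoke_mcp_tool,
    list_mcp_tools,
)


async def main() -> None:
    endpoint = "http://localhost:8000/sse"  # assumed MCP SSE endpoint
    headers = convert_header_list_to_dict(["Authorization: Bearer test-token"])

    # Discover the tools the server exposes...
    tool_defs = await list_mcp_tools(endpoint, headers)
    print([t.name for t in tool_defs.data])

    # ...and invoke one of them by name.
    result = await invoke_mcp_tool(endpoint, headers, "get_boiling_point", {"liquid_name": "polyjuice"})
    print(result.error_code, result.content)


asyncio.run(main())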
tests/common/mcp.py (new file, 116 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# we want the mcp server to be authenticated OR not, depends
from contextlib import contextmanager


@contextmanager
def make_mcp_server(required_auth_token: str | None = None):
    import threading
    import time

    import httpx
    import uvicorn
    from mcp import types
    from mcp.server.fastmcp import Context, FastMCP
    from mcp.server.sse import SseServerTransport
    from starlette.applications import Starlette
    from starlette.responses import Response
    from starlette.routing import Mount, Route

    server = FastMCP("FastMCP Test Server")

    @server.tool()
    async def greet_everyone(
        url: str, ctx: Context
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        return [types.TextContent(type="text", text="Hello, world!")]

    @server.tool()
    async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
        """
        Returns the boiling point of a liquid in Celcius or Fahrenheit.

        :param liquid_name: The name of the liquid
        :param celcius: Whether to return the boiling point in Celcius
        :return: The boiling point of the liquid in Celcius or Fahrenheit
        """
        if liquid_name.lower() == "polyjuice":
            if celcius:
                return -100
            else:
                return -212
        else:
            return -1

    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        from starlette.exceptions import HTTPException

        auth_header = request.headers.get("Authorization")
        auth_token = None
        if auth_header and auth_header.startswith("Bearer "):
            auth_token = auth_header.split(" ")[1]

        if required_auth_token and auth_token != required_auth_token:
            raise HTTPException(status_code=401, detail="Unauthorized")

        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            await server._mcp_server.run(
                streams[0],
                streams[1],
                server._mcp_server.create_initialization_options(),
            )
            return Response()

    app = Starlette(
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )

    def get_open_port():
        import socket

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    port = get_open_port()

    config = uvicorn.Config(app, host="0.0.0.0", port=port)
    server_instance = uvicorn.Server(config)
    app.state.uvicorn_server = server_instance

    def run_server():
        server_instance.run()

    # Start the server in a new thread
    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()

    # Polling until the server is ready
    timeout = 10
    start_time = time.time()

    server_url = f"http://localhost:{port}/sse"
    while time.time() - start_time < timeout:
        try:
            response = httpx.get(server_url)
            if response.status_code in [200, 401]:
                break
        except httpx.RequestError:
            pass
        time.sleep(0.1)

    yield {"server_url": server_url}

    # Tell server to exit
    server_instance.should_exit = True
    server_thread.join(timeout=5)
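An illustrative sketch (not part of the commit) of using the fixture with authentication enabled; requests without the matching bearer token are rejected with a 401, which the MCP utilities translate into AuthenticationRequiredError. The token value is an assumption:

# Illustrative only: require a bearer token and point an "mcp" tool at the local server.
from tests.common.mcp import make_mcp_server

with make_mcp_server(required_auth_token="test-token") as server_info:
    tools = [
        {
            "type": "mcp",
            "server_label": "localmcp",
            "server_url": server_info["server_url"],
            "headers": {"Authorization": "Bearer test-token"},
        }
    ]
    # ... pass `tools` to client.responses.create(...) as the verification test case below does.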
Responses verification test cases (YAML fixtures)

@@ -31,6 +31,20 @@ test_response_web_search:
             search_context_size: "low"
     output: "128"
 
+test_response_mcp_tool:
+  test_name: test_response_mcp_tool
+  test_params:
+    case:
+      - case_id: "boiling_point_tool"
+        input: "What is the boiling point of polyjuice?"
+        tools:
+          - type: mcp
+            server_label: "localmcp"
+            server_url: "<FILLED_BY_TEST_RUNNER>"
+            headers:
+              Authorization: "Bearer test-token"
+        output: "Hello, world!"
+
 test_response_custom_tool:
   test_name: test_response_custom_tool
   test_params:
Responses verification tests

@@ -4,9 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
+
 import pytest
 
+from tests.common.mcp import make_mcp_server
 from tests.verifications.openai_api.fixtures.fixtures import (
     case_id_generator,
     get_base_test_name,
@@ -124,6 +126,47 @@ def test_response_non_streaming_web_search
     assert case["output"].lower() in response.output_text.lower().strip()
 
 
+@pytest.mark.parametrize(
+    "case",
+    responses_test_cases["test_response_mcp_tool"]["test_params"]["case"],
+    ids=case_id_generator,
+)
+def test_response_non_streaming_mcp_tool(request, openai_client, model, provider, verification_config, case):
+    test_name_base = get_base_test_name(request)
+    if should_skip_test(verification_config, provider, model, test_name_base):
+        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+    with make_mcp_server() as mcp_server_info:
+        tools = case["tools"]
+        for tool in tools:
+            if tool["type"] == "mcp":
+                tool["server_url"] = mcp_server_info["server_url"]
+
+        response = openai_client.responses.create(
+            model=model,
+            input=case["input"],
+            tools=tools,
+            stream=False,
+        )
+        assert len(response.output) >= 3
+        list_tools = response.output[0]
+        assert list_tools.type == "mcp_list_tools"
+        assert list_tools.server_label == "localmcp"
+        assert len(list_tools.tools) == 2
+        assert {t["name"] for t in list_tools.tools} == {"get_boiling_point", "greet_everyone"}
+
+        call = response.output[1]
+        assert call.type == "mcp_call"
+        assert call.name == "get_boiling_point"
+        assert json.loads(call.arguments) == {"liquid_name": "polyjuice", "celcius": True}
+        assert call.error is None
+        assert "-100" in call.output
+
+        message = response.output[2]
+        text_content = message.content[0].text
+        assert "boiling point" in text_content.lower()
+
+
 @pytest.mark.parametrize(
     "case",
     responses_test_cases["test_response_custom_tool"]["test_params"]["case"],