feat: enable MCP execution in Responses impl (#2240)
## Test Plan

```
pytest -s -v 'tests/verifications/openai_api/test_responses.py' \
  --provider=stack:together --model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
Parent: 66f09f24ed
Commit: 3faf1e4a79
15 changed files with 865 additions and 382 deletions
### .github/workflows/integration-tests.yml (vendored, 4 changes)
```diff
@@ -24,7 +24,7 @@ jobs:
       matrix:
         # Listing tests manually since some of them currently fail
         # TODO: generate matrix list from tests/integration when fixed
-        test-type: [agents, inference, datasets, inspect, scoring, post_training, providers]
+        test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime]
         client-type: [library, http]
       fail-fast: false # we want to run all tests regardless of failure

@@ -90,7 +90,7 @@ jobs:
           else
             stack_config="http://localhost:8321"
           fi
-          uv run pytest -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
+          uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
             -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
             --text-model="meta-llama/Llama-3.2-3B-Instruct" \
             --embedding-model=all-MiniLM-L6-v2
```
### .gitignore (vendored, 1 change)
```diff
@@ -6,6 +6,7 @@ dev_requirements.txt
 build
 .DS_Store
 llama_stack/configs/*
+.cursor/
 xcuserdata/
 *.hmap
 .DS_Store
```
### docs/_static/llama-stack-spec.html (vendored, 128 changes)
```diff
@@ -7218,15 +7218,15 @@
       "OpenAIResponseOutputMessageFunctionToolCall": {
         "type": "object",
         "properties": {
-          "arguments": {
-            "type": "string"
-          },
           "call_id": {
             "type": "string"
           },
           "name": {
             "type": "string"
           },
+          "arguments": {
+            "type": "string"
+          },
           "type": {
             "type": "string",
             "const": "function_call",
@@ -7241,12 +7241,10 @@
         },
         "additionalProperties": false,
         "required": [
-          "arguments",
           "call_id",
           "name",
-          "type",
-          "id",
-          "status"
+          "arguments",
+          "type"
         ],
         "title": "OpenAIResponseOutputMessageFunctionToolCall"
       },
@@ -7412,6 +7410,12 @@
         },
         {
           "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
+        },
+        {
+          "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPCall"
+        },
+        {
+          "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
         }
       ],
       "discriminator": {
@@ -7419,10 +7423,118 @@
         "mapping": {
           "message": "#/components/schemas/OpenAIResponseMessage",
           "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall",
-          "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
+          "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
+          "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
+          "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
         }
       }
     },
+    "OpenAIResponseOutputMessageMCPCall": {
+      "type": "object",
+      "properties": {
+        "id": { "type": "string" },
+        "type": { "type": "string", "const": "mcp_call", "default": "mcp_call" },
+        "arguments": { "type": "string" },
+        "name": { "type": "string" },
+        "server_label": { "type": "string" },
+        "error": { "type": "string" },
+        "output": { "type": "string" }
+      },
+      "additionalProperties": false,
+      "required": ["id", "type", "arguments", "name", "server_label"],
+      "title": "OpenAIResponseOutputMessageMCPCall"
+    },
+    "OpenAIResponseOutputMessageMCPListTools": {
+      "type": "object",
+      "properties": {
+        "id": { "type": "string" },
+        "type": { "type": "string", "const": "mcp_list_tools", "default": "mcp_list_tools" },
+        "server_label": { "type": "string" },
+        "tools": {
+          "type": "array",
+          "items": {
+            "type": "object",
+            "properties": {
+              "input_schema": {
+                "type": "object",
+                "additionalProperties": {
+                  "oneOf": [
+                    { "type": "null" },
+                    { "type": "boolean" },
+                    { "type": "number" },
+                    { "type": "string" },
+                    { "type": "array" },
+                    { "type": "object" }
+                  ]
+                }
+              },
+              "name": { "type": "string" },
+              "description": { "type": "string" }
+            },
+            "additionalProperties": false,
+            "required": ["input_schema", "name"],
+            "title": "MCPListToolsTool"
+          }
+        }
+      },
+      "additionalProperties": false,
+      "required": ["id", "type", "server_label", "tools"],
+      "title": "OpenAIResponseOutputMessageMCPListTools"
+    },
     "OpenAIResponseObjectStream": {
       "oneOf": [
         {
```
### docs/_static/llama-stack-spec.yaml (vendored, 81 changes)
```diff
@@ -5075,12 +5075,12 @@ components:
     "OpenAIResponseOutputMessageFunctionToolCall":
       type: object
       properties:
-        arguments:
-          type: string
         call_id:
           type: string
         name:
           type: string
+        arguments:
+          type: string
         type:
           type: string
           const: function_call
@@ -5091,12 +5091,10 @@ components:
           type: string
       additionalProperties: false
       required:
-        - arguments
         - call_id
         - name
+        - arguments
         - type
-        - id
-        - status
       title: >-
         OpenAIResponseOutputMessageFunctionToolCall
     "OpenAIResponseOutputMessageWebSearchToolCall":
@@ -5214,12 +5212,85 @@ components:
         - $ref: '#/components/schemas/OpenAIResponseMessage'
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
+        - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
+        - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
       discriminator:
         propertyName: type
         mapping:
           message: '#/components/schemas/OpenAIResponseMessage'
           web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
           function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
+          mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
+          mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
+    OpenAIResponseOutputMessageMCPCall:
+      type: object
+      properties:
+        id:
+          type: string
+        type:
+          type: string
+          const: mcp_call
+          default: mcp_call
+        arguments:
+          type: string
+        name:
+          type: string
+        server_label:
+          type: string
+        error:
+          type: string
+        output:
+          type: string
+      additionalProperties: false
+      required:
+        - id
+        - type
+        - arguments
+        - name
+        - server_label
+      title: OpenAIResponseOutputMessageMCPCall
+    OpenAIResponseOutputMessageMCPListTools:
+      type: object
+      properties:
+        id:
+          type: string
+        type:
+          type: string
+          const: mcp_list_tools
+          default: mcp_list_tools
+        server_label:
+          type: string
+        tools:
+          type: array
+          items:
+            type: object
+            properties:
+              input_schema:
+                type: object
+                additionalProperties:
+                  oneOf:
+                    - type: 'null'
+                    - type: boolean
+                    - type: number
+                    - type: string
+                    - type: array
+                    - type: object
+              name:
+                type: string
+              description:
+                type: string
+            additionalProperties: false
+            required:
+              - input_schema
+              - name
+            title: MCPListToolsTool
+      additionalProperties: false
+      required:
+        - id
+        - type
+        - server_label
+        - tools
+      title: OpenAIResponseOutputMessageMCPListTools
     OpenAIResponseObjectStream:
       oneOf:
         - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
```
```diff
@@ -10,6 +10,9 @@ from pydantic import BaseModel, Field

 from llama_stack.schema_utils import json_schema_type, register_schema

+# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
+# take their YAML and generate this file automatically. Their YAML is available.
+

 @json_schema_type
 class OpenAIResponseError(BaseModel):

@@ -79,16 +82,45 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):

 @json_schema_type
 class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
-    arguments: str
     call_id: str
     name: str
+    arguments: str
     type: Literal["function_call"] = "function_call"
-    id: str
-    status: str
+    id: str | None = None
+    status: str | None = None
+
+
+@json_schema_type
+class OpenAIResponseOutputMessageMCPCall(BaseModel):
+    id: str
+    type: Literal["mcp_call"] = "mcp_call"
+    arguments: str
+    name: str
+    server_label: str
+    error: str | None = None
+    output: str | None = None
+
+
+class MCPListToolsTool(BaseModel):
+    input_schema: dict[str, Any]
+    name: str
+    description: str | None = None
+
+
+@json_schema_type
+class OpenAIResponseOutputMessageMCPListTools(BaseModel):
+    id: str
+    type: Literal["mcp_list_tools"] = "mcp_list_tools"
+    server_label: str
+    tools: list[MCPListToolsTool]


 OpenAIResponseOutput = Annotated[
-    OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
+    OpenAIResponseMessage
+    | OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFunctionToolCall
+    | OpenAIResponseOutputMessageMCPCall
+    | OpenAIResponseOutputMessageMCPListTools,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
```
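For orientation, here is a minimal sketch of what the two new output item types carry once the schema above is in place. The field names and types come from the diff; the concrete values are made-up placeholders.

```python
# Minimal sketch: constructing the new MCP output items defined above.
# Field names/types come from the schema changes; the values are placeholders.
from llama_stack.apis.agents.openai_responses import (
    MCPListToolsTool,
    OpenAIResponseOutputMessageMCPCall,
    OpenAIResponseOutputMessageMCPListTools,
)

list_tools_item = OpenAIResponseOutputMessageMCPListTools(
    id="mcp_list_123",
    server_label="localmcp",
    tools=[
        MCPListToolsTool(
            name="get_boiling_point",
            description="Returns the boiling point of a liquid",
            input_schema={"type": "object", "properties": {"liquid_name": {"type": "string"}}},
        )
    ],
)

call_item = OpenAIResponseOutputMessageMCPCall(
    id="call_abc",
    arguments='{"liquid_name": "polyjuice"}',
    name="get_boiling_point",
    server_label="localmcp",
    output="-100",
)

# Both serialize with their discriminator `type` filled in automatically.
print(list_tools_item.model_dump()["type"])  # "mcp_list_tools"
print(call_item.model_dump()["type"])        # "mcp_call"
```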
```diff
@@ -14,6 +14,7 @@ from pydantic import BaseModel

 from llama_stack.apis.agents import Order
 from llama_stack.apis.agents.openai_responses import (
+    AllowedToolsFilter,
     ListOpenAIResponseInputItem,
     ListOpenAIResponseObject,
     OpenAIResponseInput,
@@ -22,7 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
-    OpenAIResponseInputToolFunction,
+    OpenAIResponseInputToolMCP,
     OpenAIResponseMessage,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
@@ -51,11 +52,12 @@ from llama_stack.apis.inference.inference import (
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
 )
-from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
 from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore
+from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools

 logger = get_logger(name=__name__, category="openai_responses")

@@ -168,6 +170,15 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
     response: OpenAIResponseObject


+class ChatCompletionContext(BaseModel):
+    model: str
+    messages: list[OpenAIMessageParam]
+    tools: list[ChatCompletionToolParam] | None = None
+    mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
+    stream: bool
+    temperature: float | None
+
+
 class OpenAIResponsesImpl:
     def __init__(
         self,
@@ -255,13 +266,32 @@ class OpenAIResponsesImpl:
         temperature: float | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
     ):
+        output_messages: list[OpenAIResponseOutput] = []
+
         stream = False if stream is None else stream
+
+        # Huge TODO: we need to run this in a loop, until morale improves
+
+        # Create context to run "chat completion"
         input = await self._prepend_previous_response(input, previous_response_id)
         messages = await _convert_response_input_to_chat_messages(input)
         await self._prepend_instructions(messages, instructions)
-        chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
+        chat_tools, mcp_tool_to_server, mcp_list_message = (
+            await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
+        )
+        if mcp_list_message:
+            output_messages.append(mcp_list_message)
+
+        ctx = ChatCompletionContext(
+            model=model,
+            messages=messages,
+            tools=chat_tools,
+            mcp_tool_to_server=mcp_tool_to_server,
+            stream=stream,
+            temperature=temperature,
+        )
+
+        # Run inference
         chat_response = await self.inference_api.openai_chat_completion(
             model=model,
             messages=messages,
@@ -270,6 +300,7 @@ class OpenAIResponsesImpl:
             temperature=temperature,
         )

+        # Collect output
         if stream:
             # TODO: refactor this into a separate method that handles streaming
             chat_response_id = ""
@@ -328,11 +359,11 @@ class OpenAIResponsesImpl:
             # dump and reload to map to our pydantic types
             chat_response = OpenAIChatCompletion(**chat_response.model_dump())

-        output_messages: list[OpenAIResponseOutput] = []
+        # Execute tool calls if any
         for choice in chat_response.choices:
             if choice.message.tool_calls and tools:
                 # Assume if the first tool is a function, all tools are functions
-                if isinstance(tools[0], OpenAIResponseInputToolFunction):
+                if tools[0].type == "function":
                     for tool_call in choice.message.tool_calls:
                         output_messages.append(
                             OpenAIResponseOutputMessageFunctionToolCall(
@@ -344,11 +375,12 @@ class OpenAIResponsesImpl:
                             )
                         )
                 else:
-                    output_messages.extend(
-                        await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
-                    )
+                    tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
+                    output_messages.extend(tool_messages)
             else:
                 output_messages.append(await _convert_chat_choice_to_response_message(choice))

+        # Create response object
         response = OpenAIResponseObject(
             created_at=chat_response.created,
             id=f"resp-{uuid.uuid4()}",
@@ -359,9 +391,8 @@ class OpenAIResponsesImpl:
         )
         logger.debug(f"OpenAI Responses response: {response}")

+        # Store response if requested
         if store:
-            # Store in kvstore
-
             new_input_id = f"msg_{uuid.uuid4()}"
             if isinstance(input, str):
                 # synthesize a message from the input string
@@ -403,15 +434,20 @@ class OpenAIResponsesImpl:

     async def _convert_response_tools_to_chat_tools(
         self, tools: list[OpenAIResponseInputTool]
-    ) -> list[ChatCompletionToolParam]:
-        chat_tools: list[ChatCompletionToolParam] = []
-        for input_tool in tools:
-            # TODO: Handle other tool types
-            if input_tool.type == "function":
-                chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
-            elif input_tool.type == "web_search":
-                tool_name = "web_search"
-                tool = await self.tool_groups_api.get_tool(tool_name)
-                tool_def = ToolDefinition(
-                    tool_name=tool_name,
-                    description=tool.description,
+    ) -> tuple[
+        list[ChatCompletionToolParam],
+        dict[str, OpenAIResponseInputToolMCP],
+        OpenAIResponseOutput | None,
+    ]:
+        from llama_stack.apis.agents.openai_responses import (
+            MCPListToolsTool,
+            OpenAIResponseOutputMessageMCPListTools,
+        )
+        from llama_stack.apis.tools.tools import Tool
+
+        mcp_tool_to_server = {}
+
+        def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
+            tool_def = ToolDefinition(
+                tool_name=tool_name,
+                description=tool.description,
@@ -425,78 +461,106 @@ class OpenAIResponsesImpl:
                     for param in tool.parameters
                 },
             )
-                chat_tool = convert_tooldef_to_openai_tool(tool_def)
-                chat_tools.append(chat_tool)
+            return convert_tooldef_to_openai_tool(tool_def)
+
+        mcp_list_message = None
+        chat_tools: list[ChatCompletionToolParam] = []
+        for input_tool in tools:
+            # TODO: Handle other tool types
+            if input_tool.type == "function":
+                chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
+            elif input_tool.type == "web_search":
+                tool_name = "web_search"
+                tool = await self.tool_groups_api.get_tool(tool_name)
+                if not tool:
+                    raise ValueError(f"Tool {tool_name} not found")
+                chat_tools.append(make_openai_tool(tool_name, tool))
+            elif input_tool.type == "mcp":
+                always_allowed = None
+                never_allowed = None
+                if input_tool.allowed_tools:
+                    if isinstance(input_tool.allowed_tools, list):
+                        always_allowed = input_tool.allowed_tools
+                    elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
+                        always_allowed = input_tool.allowed_tools.always
+                        never_allowed = input_tool.allowed_tools.never
+
+                tool_defs = await list_mcp_tools(
+                    endpoint=input_tool.server_url,
+                    headers=input_tool.headers or {},
+                )
+
+                mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
+                    id=f"mcp_list_{uuid.uuid4()}",
+                    status="completed",
+                    server_label=input_tool.server_label,
+                    tools=[],
+                )
+                for t in tool_defs.data:
+                    if never_allowed and t.name in never_allowed:
+                        continue
+                    if not always_allowed or t.name in always_allowed:
+                        chat_tools.append(make_openai_tool(t.name, t))
+                        if t.name in mcp_tool_to_server:
+                            raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
+                        mcp_tool_to_server[t.name] = input_tool
+                        mcp_list_message.tools.append(
+                            MCPListToolsTool(
+                                name=t.name,
+                                description=t.description,
+                                input_schema={
+                                    "type": "object",
+                                    "properties": {
+                                        p.name: {
+                                            "type": p.parameter_type,
+                                            "description": p.description,
+                                        }
+                                        for p in t.parameters
+                                    },
+                                    "required": [p.name for p in t.parameters if p.required],
+                                },
+                            )
+                        )
             else:
                 raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
-        return chat_tools
+        return chat_tools, mcp_tool_to_server, mcp_list_message

     async def _execute_tool_and_return_final_output(
         self,
-        model_id: str,
-        stream: bool,
         choice: OpenAIChoice,
-        messages: list[OpenAIMessageParam],
-        temperature: float,
+        ctx: ChatCompletionContext,
     ) -> list[OpenAIResponseOutput]:
         output_messages: list[OpenAIResponseOutput] = []

-        # If the choice is not an assistant message, we don't need to execute any tools
         if not isinstance(choice.message, OpenAIAssistantMessageParam):
             return output_messages

-        # If the assistant message doesn't have any tool calls, we don't need to execute any tools
         if not choice.message.tool_calls:
             return output_messages

-        # Copy the messages list to avoid mutating the original list
-        messages = messages.copy()
+        next_turn_messages = ctx.messages.copy()

         # Add the assistant message with tool_calls response to the messages list
-        messages.append(choice.message)
+        next_turn_messages.append(choice.message)

         for tool_call in choice.message.tool_calls:
-            tool_call_id = tool_call.id
-            function = tool_call.function
-
-            # If for some reason the tool call doesn't have a function or id, we can't execute it
-            if not function or not tool_call_id:
-                continue
-
             # TODO: telemetry spans for tool calls
-            result = await self._execute_tool_call(function)
-
-            # Handle tool call failure
-            if not result:
-                output_messages.append(
-                    OpenAIResponseOutputMessageWebSearchToolCall(
-                        id=tool_call_id,
-                        status="failed",
-                    )
-                )
-                continue
-
-            output_messages.append(
-                OpenAIResponseOutputMessageWebSearchToolCall(
-                    id=tool_call_id,
-                    status="completed",
-                ),
-            )
-
-            result_content = ""
-            # TODO: handle other result content types and lists
-            if isinstance(result.content, str):
-                result_content = result.content
-            messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
+            tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
+            if tool_call_log:
+                output_messages.append(tool_call_log)
+            if further_input:
+                next_turn_messages.append(further_input)

         tool_results_chat_response = await self.inference_api.openai_chat_completion(
-            model=model_id,
-            messages=messages,
-            stream=stream,
-            temperature=temperature,
+            model=ctx.model,
+            messages=next_turn_messages,
+            stream=ctx.stream,
+            temperature=ctx.temperature,
         )
-        # type cast to appease mypy
+        # type cast to appease mypy: this is needed because we don't handle streaming properly :)
         tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
+
+        # Huge TODO: these are NOT the final outputs, we must keep the loop going
         tool_final_outputs = [
             await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
         ]
@@ -506,15 +570,86 @@ class OpenAIResponsesImpl:

     async def _execute_tool_call(
         self,
-        function: OpenAIChatCompletionToolCallFunction,
-    ) -> ToolInvocationResult | None:
-        if not function.name:
-            return None
-        function_args = json.loads(function.arguments) if function.arguments else {}
-        logger.info(f"executing tool call: {function.name} with args: {function_args}")
-        result = await self.tool_runtime_api.invoke_tool(
-            tool_name=function.name,
-            kwargs=function_args,
-        )
-        logger.debug(f"tool call {function.name} completed with result: {result}")
-        return result
+        tool_call: OpenAIChatCompletionToolCall,
+        ctx: ChatCompletionContext,
+    ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
+        from llama_stack.providers.utils.inference.prompt_adapter import (
+            interleaved_content_as_str,
+        )
+
+        tool_call_id = tool_call.id
+        function = tool_call.function
+
+        if not function or not tool_call_id or not function.name:
+            return None, None
+
+        error_exc = None
+        result = None
+        try:
+            if function.name in ctx.mcp_tool_to_server:
+                mcp_tool = ctx.mcp_tool_to_server[function.name]
+                result = await invoke_mcp_tool(
+                    endpoint=mcp_tool.server_url,
+                    headers=mcp_tool.headers or {},
+                    tool_name=function.name,
+                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                )
+            else:
+                result = await self.tool_runtime_api.invoke_tool(
+                    tool_name=function.name,
+                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                )
+        except Exception as e:
+            error_exc = e
+
+        if function.name in ctx.mcp_tool_to_server:
+            from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
+
+            message = OpenAIResponseOutputMessageMCPCall(
+                id=tool_call_id,
+                arguments=function.arguments,
+                name=function.name,
+                server_label=ctx.mcp_tool_to_server[function.name].server_label,
+            )
+            if error_exc:
+                message.error = str(error_exc)
+            elif (result.error_code and result.error_code > 0) or result.error_message:
+                message.error = f"Error (code {result.error_code}): {result.error_message}"
+            elif result.content:
+                message.output = interleaved_content_as_str(result.content)
+        else:
+            if function.name == "web_search":
+                message = OpenAIResponseOutputMessageWebSearchToolCall(
+                    id=tool_call_id,
+                    status="completed",
+                )
+                if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+                    message.status = "failed"
+            else:
+                raise ValueError(f"Unknown tool {function.name} called")
+
+        input_message = None
+        if result and result.content:
+            if isinstance(result.content, str):
+                content = result.content
+            elif isinstance(result.content, list):
+                from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+
+                content = []
+                for item in result.content:
+                    if isinstance(item, TextContentItem):
+                        part = OpenAIChatCompletionContentPartTextParam(text=item.text)
+                    elif isinstance(item, ImageContentItem):
+                        if item.image.data:
+                            url = f"data:image;base64,{item.image.data}"
+                        else:
+                            url = item.image.url
+                        part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
+                    else:
+                        raise ValueError(f"Unknown result content type: {type(item)}")
+                    content.append(part)
+            else:
+                raise ValueError(f"Unknown result content type: {type(result.content)}")
+            input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
+
+        return message, input_message
```
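With the changes above, a Responses request can carry a tool of type `mcp`: its tools are listed from the given `server_url`, filtered by `allowed_tools`, and invoked during the turn, with `mcp_list_tools` and `mcp_call` items appearing in the output. A rough sketch of such a request through an OpenAI-compatible client follows; the base URL, model id, and MCP server URL are placeholders and assumptions, not values taken from this change.

```python
# Rough sketch of exercising the new MCP path through the OpenAI-compatible
# Responses endpoint. base_url, model, and server_url are placeholders; adjust
# them for your running distribution.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    input="What is the boiling point of polyjuice?",
    tools=[
        {
            "type": "mcp",
            "server_label": "localmcp",
            "server_url": "http://localhost:8000/sse",
            # optional: restrict which MCP tools are exposed to the model
            "allowed_tools": ["get_boiling_point"],
        }
    ],
)

# Expected output items per the schema changes: an mcp_list_tools item,
# zero or more mcp_call items, and a final assistant message.
for item in response.output:
    print(item.type)
```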
```diff
@@ -4,53 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from contextlib import asynccontextmanager
 from typing import Any
 from urllib.parse import urlparse

-import exceptiongroup
-import httpx
-from mcp import ClientSession
-from mcp import types as mcp_types
-from mcp.client.sse import sse_client
-
-from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.datatypes import Api
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
-    ToolDef,
     ToolInvocationResult,
-    ToolParameter,
     ToolRuntime,
 )
-from llama_stack.distribution.datatypes import AuthenticationRequiredError
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools

 from .config import MCPProviderConfig

 logger = get_logger(__name__, category="tools")


-@asynccontextmanager
-async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
-    try:
-        async with sse_client(endpoint, headers=headers) as streams:
-            async with ClientSession(*streams) as session:
-                await session.initialize()
-                yield session
-    except BaseException as e:
-        if isinstance(e, exceptiongroup.BaseExceptionGroup):
-            for exc in e.exceptions:
-                if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
-                    raise AuthenticationRequiredError(exc) from exc
-        elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
-            raise AuthenticationRequiredError(e) from e
-
-        raise
-
-
 class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
     def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
         self.config = config
@@ -64,32 +37,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
         # this endpoint should be retrieved by getting the tool group right?
         if mcp_endpoint is None:
             raise ValueError("mcp_endpoint is required")

         headers = await self.get_headers_from_request(mcp_endpoint.uri)
-        tools = []
-        async with sse_client_wrapper(mcp_endpoint.uri, headers) as session:
-            tools_result = await session.list_tools()
-            for tool in tools_result.tools:
-                parameters = []
-                for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
-                    parameters.append(
-                        ToolParameter(
-                            name=param_name,
-                            parameter_type=param_schema.get("type", "string"),
-                            description=param_schema.get("description", ""),
-                        )
-                    )
-                tools.append(
-                    ToolDef(
-                        name=tool.name,
-                        description=tool.description,
-                        parameters=parameters,
-                        metadata={
-                            "endpoint": mcp_endpoint.uri,
-                        },
-                    )
-                )
-        return ListToolDefsResponse(data=tools)
+        return await list_mcp_tools(mcp_endpoint.uri, headers)

     async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
@@ -100,23 +49,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
             raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")

         headers = await self.get_headers_from_request(endpoint)
-        async with sse_client_wrapper(endpoint, headers) as session:
-            result = await session.call_tool(tool.identifier, kwargs)
-
-            content = []
-            for item in result.content:
-                if isinstance(item, mcp_types.TextContent):
-                    content.append(TextContentItem(text=item.text))
-                elif isinstance(item, mcp_types.ImageContent):
-                    content.append(ImageContentItem(image=item.data))
-                elif isinstance(item, mcp_types.EmbeddedResource):
-                    logger.warning(f"EmbeddedResource is not supported: {item}")
-                else:
-                    raise ValueError(f"Unknown content type: {type(item)}")
-            return ToolInvocationResult(
-                content=content,
-                error_code=1 if result.isError else 0,
-            )
+        return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)

     async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
         def canonicalize_uri(uri: str) -> str:
@@ -129,9 +62,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
         for uri, values in provider_data.mcp_headers.items():
             if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
                 continue
-            for entry in values:
-                parts = entry.split(":")
-                if len(parts) == 2:
-                    k, v = parts
-                    headers[k.strip()] = v.strip()
+            headers.update(convert_header_list_to_dict(values))
         return headers
```
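The provider now delegates header parsing to the shared `convert_header_list_to_dict` helper. A quick illustration of the conversion it performs, using a made-up second header alongside the Authorization form the tests already use:

```python
# Quick illustration of the shared header helper the provider now relies on.
# "X-Example: demo" is a made-up header used purely for illustration.
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict

headers = convert_header_list_to_dict(
    ["Authorization: Bearer test-token", "X-Example: demo"]
)
assert headers == {"Authorization": "Bearer test-token", "X-Example": "demo"}
```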
### llama_stack/providers/utils/tools/mcp.py (new file, 110 lines)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from contextlib import asynccontextmanager
from typing import Any

try:
    # for python < 3.11
    import exceptiongroup

    BaseExceptionGroup = exceptiongroup.BaseExceptionGroup
except ImportError:
    pass

import httpx
from mcp import ClientSession
from mcp import types as mcp_types
from mcp.client.sse import sse_client

from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
from llama_stack.apis.tools import (
    ListToolDefsResponse,
    ToolDef,
    ToolInvocationResult,
    ToolParameter,
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger

logger = get_logger(__name__, category="tools")


@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
    try:
        async with sse_client(endpoint, headers=headers) as streams:
            async with ClientSession(*streams) as session:
                await session.initialize()
                yield session
    except BaseException as e:
        if isinstance(e, BaseExceptionGroup):
            for exc in e.exceptions:
                if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
                    raise AuthenticationRequiredError(exc) from exc
        elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
            raise AuthenticationRequiredError(e) from e

        raise


def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]:
    headers = {}
    for header in header_list:
        parts = header.split(":")
        if len(parts) == 2:
            k, v = parts
            headers[k.strip()] = v.strip()
    return headers


async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
    tools = []
    async with sse_client_wrapper(endpoint, headers) as session:
        tools_result = await session.list_tools()
        for tool in tools_result.tools:
            parameters = []
            for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
                parameters.append(
                    ToolParameter(
                        name=param_name,
                        parameter_type=param_schema.get("type", "string"),
                        description=param_schema.get("description", ""),
                    )
                )
            tools.append(
                ToolDef(
                    name=tool.name,
                    description=tool.description,
                    parameters=parameters,
                    metadata={
                        "endpoint": endpoint,
                    },
                )
            )
    return ListToolDefsResponse(data=tools)


async def invoke_mcp_tool(
    endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult:
    async with sse_client_wrapper(endpoint, headers) as session:
        result = await session.call_tool(tool_name, kwargs)

        content: list[InterleavedContentItem] = []
        for item in result.content:
            if isinstance(item, mcp_types.TextContent):
                content.append(TextContentItem(text=item.text))
            elif isinstance(item, mcp_types.ImageContent):
                content.append(ImageContentItem(image=item.data))
            elif isinstance(item, mcp_types.EmbeddedResource):
                logger.warning(f"EmbeddedResource is not supported: {item}")
            else:
                raise ValueError(f"Unknown content type: {type(item)}")
        return ToolInvocationResult(
            content=content,
            error_code=1 if result.isError else 0,
        )
```
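A small usage sketch of the two new helpers, assuming an SSE-based MCP server is already reachable at the URL shown; the endpoint, headers, tool name, and arguments are placeholders.

```python
# Usage sketch for the new shared MCP helpers. The endpoint URL, tool name,
# and arguments are placeholders; any SSE-based MCP server would do.
import asyncio

from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools


async def main() -> None:
    endpoint = "http://localhost:8000/sse"
    headers = {"Authorization": "Bearer test-token"}  # only needed if the server requires auth

    tool_defs = await list_mcp_tools(endpoint, headers)
    for tool in tool_defs.data:
        print(tool.name, [p.name for p in tool.parameters])

    result = await invoke_mcp_tool(endpoint, headers, "get_boiling_point", {"liquid_name": "polyjuice"})
    print(result.error_code, result.content)


asyncio.run(main())
```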
```diff
@@ -67,6 +67,7 @@ unit = [
     "aiosqlite",
     "aiohttp",
     "pypdf",
+    "mcp",
     "chardet",
     "qdrant-client",
     "opentelemetry-exporter-otlp-proto-http",
```
### tests/common/mcp.py (new file, 145 lines)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# we want the mcp server to be authenticated OR not, depends
from contextlib import contextmanager

# Unfortunately the toolgroup id must be tied to the tool names because the registry
# indexes on both toolgroups and tools independently (and not jointly). That really
# needs to be fixed.
MCP_TOOLGROUP_ID = "mcp::localmcp"


@contextmanager
def make_mcp_server(required_auth_token: str | None = None):
    import threading
    import time

    import httpx
    import uvicorn
    from mcp import types
    from mcp.server.fastmcp import Context, FastMCP
    from mcp.server.sse import SseServerTransport
    from starlette.applications import Starlette
    from starlette.responses import Response
    from starlette.routing import Mount, Route

    server = FastMCP("FastMCP Test Server", log_level="WARNING")

    @server.tool()
    async def greet_everyone(
        url: str, ctx: Context
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        return [types.TextContent(type="text", text="Hello, world!")]

    @server.tool()
    async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
        """
        Returns the boiling point of a liquid in Celcius or Fahrenheit.

        :param liquid_name: The name of the liquid
        :param celcius: Whether to return the boiling point in Celcius
        :return: The boiling point of the liquid in Celcius or Fahrenheit
        """
        if liquid_name.lower() == "polyjuice":
            if celcius:
                return -100
            else:
                return -212
        else:
            return -1

    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        from starlette.exceptions import HTTPException

        auth_header = request.headers.get("Authorization")
        auth_token = None
        if auth_header and auth_header.startswith("Bearer "):
            auth_token = auth_header.split(" ")[1]

        if required_auth_token and auth_token != required_auth_token:
            raise HTTPException(status_code=401, detail="Unauthorized")

        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            await server._mcp_server.run(
                streams[0],
                streams[1],
                server._mcp_server.create_initialization_options(),
            )
        return Response()

    app = Starlette(
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )

    def get_open_port():
        import socket

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    port = get_open_port()

    # make uvicorn logs be less verbose
    config = uvicorn.Config(app, host="0.0.0.0", port=port, log_level="warning")
    server_instance = uvicorn.Server(config)
    app.state.uvicorn_server = server_instance

    def run_server():
        server_instance.run()

    # Start the server in a new thread
    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()

    # Polling until the server is ready
    timeout = 10
    start_time = time.time()

    server_url = f"http://localhost:{port}/sse"
    while time.time() - start_time < timeout:
        try:
            response = httpx.get(server_url)
            if response.status_code in [200, 401]:
                break
        except httpx.RequestError:
            pass
        time.sleep(0.1)

    try:
        yield {"server_url": server_url}
    finally:
        print("Telling SSE server to exit")
        server_instance.should_exit = True
        time.sleep(0.5)

        # Force shutdown if still running
        if server_thread.is_alive():
            try:
                if hasattr(server_instance, "servers") and server_instance.servers:
                    for srv in server_instance.servers:
                        srv.close()

                # Wait for graceful shutdown
                server_thread.join(timeout=3)
                if server_thread.is_alive():
                    print("Warning: Server thread still alive after shutdown attempt")
            except Exception as e:
                print(f"Error during server shutdown: {e}")

        # CRITICAL: Reset SSE global state to prevent event loop contamination
        # Reset the SSE AppStatus singleton that stores anyio.Event objects
        from sse_starlette.sse import AppStatus

        AppStatus.should_exit = False
        AppStatus.should_exit_event = None
        print("SSE server exited")
```
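Outside the integration tests, this fixture composes naturally with the shared helpers above. A self-contained sketch, using only names introduced in this change, assuming the unauthenticated path (no `required_auth_token`):

```python
# Self-contained sketch: stand up the in-process test MCP server and list its
# tools with the shared helper. Everything referenced here is introduced in this PR.
import asyncio

from llama_stack.providers.utils.tools.mcp import list_mcp_tools
from tests.common.mcp import make_mcp_server

with make_mcp_server() as server_info:  # no auth token, so no headers are required
    tool_defs = asyncio.run(list_mcp_tools(server_info["server_url"], headers={}))
    print(sorted(t.name for t in tool_defs.data))  # ['get_boiling_point', 'greet_everyone']
```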
@ -5,117 +5,43 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import socket
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import mcp.types as types
|
|
||||||
import pytest
|
import pytest
|
||||||
-import uvicorn
 from llama_stack_client import Agent
-from mcp.server.fastmcp import Context, FastMCP
-from mcp.server.sse import SseServerTransport
-from starlette.applications import Starlette
-from starlette.exceptions import HTTPException
-from starlette.responses import Response
-from starlette.routing import Mount, Route

 from llama_stack import LlamaStackAsLibraryClient
+from llama_stack.apis.models import ModelType
 from llama_stack.distribution.datatypes import AuthenticationRequiredError

 AUTH_TOKEN = "test-token"

+from tests.common.mcp import MCP_TOOLGROUP_ID, make_mcp_server

-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def mcp_server():
-    server = FastMCP("FastMCP Test Server")
+    with make_mcp_server(required_auth_token=AUTH_TOKEN) as mcp_server_info:
+        yield mcp_server_info
-
-    @server.tool()
-    async def greet_everyone(
-        url: str, ctx: Context
-    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
-        return [types.TextContent(type="text", text="Hello, world!")]
-
-    sse = SseServerTransport("/messages/")
-
-    async def handle_sse(request):
-        auth_header = request.headers.get("Authorization")
-        auth_token = None
-        if auth_header and auth_header.startswith("Bearer "):
-            auth_token = auth_header.split(" ")[1]
-
-        if auth_token != AUTH_TOKEN:
-            raise HTTPException(status_code=401, detail="Unauthorized")
-
-        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
-            await server._mcp_server.run(
-                streams[0],
-                streams[1],
-                server._mcp_server.create_initialization_options(),
-            )
-        return Response()
-
-    app = Starlette(
-        routes=[
-            Route("/sse", endpoint=handle_sse),
-            Mount("/messages/", app=sse.handle_post_message),
-        ],
-    )
-
-    def get_open_port():
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
-            sock.bind(("", 0))
-            return sock.getsockname()[1]
-
-    port = get_open_port()
-
-    config = uvicorn.Config(app, host="0.0.0.0", port=port)
-    server_instance = uvicorn.Server(config)
-    app.state.uvicorn_server = server_instance
-
-    def run_server():
-        server_instance.run()
-
-    # Start the server in a new thread
-    server_thread = threading.Thread(target=run_server, daemon=True)
-    server_thread.start()
-
-    # Polling until the server is ready
-    timeout = 10
-    start_time = time.time()
-
-    while time.time() - start_time < timeout:
-        try:
-            response = httpx.get(f"http://localhost:{port}/sse")
-            if response.status_code == 401:
-                break
-        except httpx.RequestError:
-            pass
-        time.sleep(0.1)
-
-    yield port
-
-    # Tell server to exit
-    server_instance.should_exit = True
-    server_thread.join(timeout=5)


 def test_mcp_invocation(llama_stack_client, mcp_server):
-    port = mcp_server
-    test_toolgroup_id = "remote::mcptest"
+    if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
+        pytest.skip("The local MCP server only reliably reachable from library client.")

+    test_toolgroup_id = MCP_TOOLGROUP_ID
+    uri = mcp_server["server_url"]

     # registering itself should fail since it requires listing tools
     with pytest.raises(Exception, match="Unauthorized"):
         llama_stack_client.toolgroups.register(
             toolgroup_id=test_toolgroup_id,
             provider_id="model-context-protocol",
-            mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
+            mcp_endpoint=dict(uri=uri),
         )

     provider_data = {
         "mcp_headers": {
-            f"http://localhost:{port}/sse": [
+            uri: [
                 f"Authorization: Bearer {AUTH_TOKEN}",
             ],
         },
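Note: the next hunk uses `auth_headers`, which is built in lines elided between these two hunks. A hedged sketch of the presumed wiring, only so the hunk reads on its own — llama-stack clients generally forward per-request provider data as a JSON-encoded `X-LlamaStack-Provider-Data` header, but that exact detail is an assumption here, not something this diff shows:

```python
import json

# Assumption (not part of this diff): provider_data is forwarded to the stack as a
# JSON-encoded header; the MCP tool runtime then reads "mcp_headers" from it and
# attaches the Authorization header when it contacts the MCP server.
auth_headers = {
    "X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
```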
@@ -133,24 +59,18 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
     llama_stack_client.toolgroups.register(
         toolgroup_id=test_toolgroup_id,
         provider_id="model-context-protocol",
-        mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
+        mcp_endpoint=dict(uri=uri),
         extra_headers=auth_headers,
     )
     response = llama_stack_client.tools.list(
         toolgroup_id=test_toolgroup_id,
         extra_headers=auth_headers,
     )
-    assert len(response) == 1
-    assert response[0].identifier == "greet_everyone"
-    assert response[0].type == "tool"
-    assert len(response[0].parameters) == 1
-    p = response[0].parameters[0]
-    assert p.name == "url"
-    assert p.parameter_type == "string"
-    assert p.required
+    assert len(response) == 2
+    assert {t.identifier for t in response} == {"greet_everyone", "get_boiling_point"}

     response = llama_stack_client.tool_runtime.invoke_tool(
-        tool_name=response[0].identifier,
+        tool_name="greet_everyone",
         kwargs=dict(url="https://www.google.com"),
         extra_headers=auth_headers,
     )
@@ -159,7 +79,9 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
     assert content[0].type == "text"
     assert content[0].text == "Hello, world!"

-    models = llama_stack_client.models.list()
+    models = [
+        m for m in llama_stack_client.models.list() if m.model_type == ModelType.llm and "guard" not in m.identifier
+    ]
     model_id = models[0].identifier
     print(f"Using model: {model_id}")
     agent = Agent(
@@ -174,7 +96,7 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
         messages=[
             {
                 "role": "user",
-                "content": "Yo. Use tools.",
+                "content": "Say hi to the world. Use tools to do so.",
             }
         ],
         stream=False,
@@ -196,7 +118,6 @@ def test_mcp_invocation(llama_stack_client, mcp_server):

     third = steps[2]
     assert third.step_type == "inference"
-    assert len(third.api_model_response.tool_calls) == 0

     # when streaming, we currently don't check auth headers upfront and fail the request
     # early. but we should at least be generating a 401 later in the process.
@@ -205,7 +126,7 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
         messages=[
             {
                 "role": "user",
-                "content": "Yo. Use tools.",
+                "content": "What is the boiling point of polyjuice? Use tools to answer.",
             }
         ],
         stream=True,
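The refactor above, and the registration-test diff that follows, replace the per-file server boilerplate with `make_mcp_server` and `MCP_TOOLGROUP_ID` from `tests.common.mcp`, a shared helper that is not part of this excerpt. Below is a rough sketch of what such a helper plausibly looks like, pieced together from the fixture code deleted above and from the tool names the tests assert on; the tool signatures, the `MCP_TOOLGROUP_ID` value, and anything beyond the `server_url` key in the yielded dict are assumptions, not the actual implementation:

```python
# Sketch only: the real helper lives in tests/common/mcp.py, which is not shown in this
# diff. Everything below is inferred from the boilerplate deleted above.
import socket
import threading
import time
from contextlib import contextmanager

import httpx
import uvicorn
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.responses import Response
from starlette.routing import Mount, Route

MCP_TOOLGROUP_ID = "mcp::localmcp"  # assumed value


@contextmanager
def make_mcp_server(required_auth_token: str | None = None):
    server = FastMCP("FastMCP Test Server")

    @server.tool()
    async def greet_everyone(url: str) -> str:
        return "Hello, world!"

    @server.tool()
    async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
        # Made-up value so the tests have something deterministic to assert on.
        return -100 if liquid_name == "polyjuice" else 100

    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        # Optional bearer-token check, mirroring the auth logic removed from the fixture.
        if required_auth_token:
            auth_header = request.headers.get("Authorization", "")
            if auth_header != f"Bearer {required_auth_token}":
                raise HTTPException(status_code=401, detail="Unauthorized")
        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            await server._mcp_server.run(
                streams[0], streams[1], server._mcp_server.create_initialization_options()
            )
        return Response()

    app = Starlette(
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )

    # Pick a free port, then run uvicorn in a daemon thread (same approach as the
    # deleted fixture).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        port = sock.getsockname()[1]

    server_instance = uvicorn.Server(uvicorn.Config(app, host="0.0.0.0", port=port))
    thread = threading.Thread(target=server_instance.run, daemon=True)
    thread.start()

    # Poll until the endpoint answers; a 401 also counts as "up" when auth is required.
    # httpx.stream lets us check the status without reading the endless SSE body.
    deadline = time.time() + 10
    while time.time() < deadline:
        try:
            with httpx.stream("GET", f"http://localhost:{port}/sse") as resp:
                if resp.status_code in (200, 401):
                    break
        except httpx.RequestError:
            pass
        time.sleep(0.1)

    try:
        yield {"server_url": f"http://localhost:{port}/sse"}
    finally:
        server_instance.should_exit = True
        thread.join(timeout=5)
```

Because this is a context manager rather than a module-scoped fixture, each test now controls the server's lifetime, and the uvicorn thread is shut down even when the test body raises.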
@@ -4,88 +4,22 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import socket
-import threading
-import time
-
-import httpx
-import mcp.types as types
 import pytest
-import uvicorn
-from mcp.server.fastmcp import Context, FastMCP
-from mcp.server.sse import SseServerTransport
-from starlette.applications import Starlette
-from starlette.routing import Mount, Route
+from llama_stack import LlamaStackAsLibraryClient
+from tests.common.mcp import MCP_TOOLGROUP_ID, make_mcp_server


-@pytest.fixture(scope="module")
-def mcp_server():
-    server = FastMCP("FastMCP Test Server")
+def test_register_and_unregister_toolgroup(llama_stack_client):
+    # TODO: make this work for http client also but you need to ensure
+    # the MCP server is reachable from llama stack server
+    if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
+        pytest.skip("The local MCP server only reliably reachable from library client.")
+
+    test_toolgroup_id = MCP_TOOLGROUP_ID
-
-    @server.tool()
-    async def fetch(url: str, ctx: Context) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
-        headers = {"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"}
-        async with httpx.AsyncClient(follow_redirects=True, headers=headers) as client:
-            response = await client.get(url)
-            response.raise_for_status()
-            return [types.TextContent(type="text", text=response.text)]
-
-    sse = SseServerTransport("/messages/")
-
-    async def handle_sse(request):
-        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
-            await server._mcp_server.run(
-                streams[0],
-                streams[1],
-                server._mcp_server.create_initialization_options(),
-            )
-
-    app = Starlette(
-        debug=True,
-        routes=[
-            Route("/sse", endpoint=handle_sse),
-            Mount("/messages/", app=sse.handle_post_message),
-        ],
-    )
-
-    def get_open_port():
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
-            sock.bind(("", 0))
-            return sock.getsockname()[1]
-
-    port = get_open_port()
-
-    def run_server():
-        uvicorn.run(app, host="0.0.0.0", port=port)
-
-    # Start the server in a new thread
-    server_thread = threading.Thread(target=run_server, daemon=True)
-    server_thread.start()
-
-    # Polling until the server is ready
-    timeout = 10
-    start_time = time.time()
-
-    while time.time() - start_time < timeout:
-        try:
-            response = httpx.get(f"http://localhost:{port}/sse")
-            if response.status_code == 200:
-                break
-        except (httpx.RequestError, httpx.HTTPStatusError):
-            pass
-        time.sleep(0.1)
-
-    yield port
-
-
-def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server):
-    """
-    Integration test for registering and unregistering a toolgroup using the ToolGroups API.
-    """
-    port = mcp_server
-    test_toolgroup_id = "remote::web-fetch"
     provider_id = "model-context-protocol"

-    # Cleanup before running the test
-    toolgroups = llama_stack_client.toolgroups.list()
-    for toolgroup in toolgroups:
+    with make_mcp_server() as mcp_server_info:
+        # Cleanup before running the test
+        toolgroups = llama_stack_client.toolgroups.list()
+        for toolgroup in toolgroups:
@@ -96,7 +30,7 @@ def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server):
         llama_stack_client.toolgroups.register(
             toolgroup_id=test_toolgroup_id,
             provider_id=provider_id,
-            mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
+            mcp_endpoint=dict(uri=mcp_server_info["server_url"]),
         )

         # Verify registration
@@ -31,6 +31,18 @@ test_response_web_search:
             search_context_size: "low"
         output: "128"

+test_response_mcp_tool:
+  test_name: test_response_mcp_tool
+  test_params:
+    case:
+      - case_id: "boiling_point_tool"
+        input: "What is the boiling point of polyjuice?"
+        tools:
+          - type: mcp
+            server_label: "localmcp"
+            server_url: "<FILLED_BY_TEST_RUNNER>"
+        output: "Hello, world!"
+
 test_response_custom_tool:
   test_name: test_response_custom_tool
   test_params:
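At run time the verification test added below copies this case's `tools` list and replaces the `<FILLED_BY_TEST_RUNNER>` placeholder with the live server's URL, so the Responses call effectively receives a tool definition like the following (the URL and port are illustrative, not real values):

```python
# Illustrative expansion of the "boiling_point_tool" case after URL substitution.
tools = [
    {
        "type": "mcp",
        "server_label": "localmcp",
        "server_url": "http://localhost:12345/sse",  # hypothetical port chosen by the test server
    }
]
# This list is then passed as the `tools` argument of openai_client.responses.create(...)
# in test_response_non_streaming_mcp_tool below.
```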
@@ -4,9 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import json
+
+import httpx
 import pytest

+from llama_stack import LlamaStackAsLibraryClient
+from llama_stack.distribution.datatypes import AuthenticationRequiredError
+from tests.common.mcp import make_mcp_server
 from tests.verifications.openai_api.fixtures.fixtures import (
     case_id_generator,
     get_base_test_name,
@@ -124,6 +129,79 @@ def test_response_non_streaming_web_search(request, openai_client, model, provid
     assert case["output"].lower() in response.output_text.lower().strip()


+@pytest.mark.parametrize(
+    "case",
+    responses_test_cases["test_response_mcp_tool"]["test_params"]["case"],
+    ids=case_id_generator,
+)
+def test_response_non_streaming_mcp_tool(request, openai_client, model, provider, verification_config, case):
+    test_name_base = get_base_test_name(request)
+    if should_skip_test(verification_config, provider, model, test_name_base):
+        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+    with make_mcp_server() as mcp_server_info:
+        tools = case["tools"]
+        for tool in tools:
+            if tool["type"] == "mcp":
+                tool["server_url"] = mcp_server_info["server_url"]
+
+        response = openai_client.responses.create(
+            model=model,
+            input=case["input"],
+            tools=tools,
+            stream=False,
+        )
+        assert len(response.output) >= 3
+        list_tools = response.output[0]
+        assert list_tools.type == "mcp_list_tools"
+        assert list_tools.server_label == "localmcp"
+        assert len(list_tools.tools) == 2
+        assert {t["name"] for t in list_tools.tools} == {"get_boiling_point", "greet_everyone"}
+
+        call = response.output[1]
+        assert call.type == "mcp_call"
+        assert call.name == "get_boiling_point"
+        assert json.loads(call.arguments) == {"liquid_name": "polyjuice", "celcius": True}
+        assert call.error is None
+        assert "-100" in call.output
+
+        message = response.output[2]
+        text_content = message.content[0].text
+        assert "boiling point" in text_content.lower()
+
+    with make_mcp_server(required_auth_token="test-token") as mcp_server_info:
+        tools = case["tools"]
+        for tool in tools:
+            if tool["type"] == "mcp":
+                tool["server_url"] = mcp_server_info["server_url"]
+
+        exc_type = (
+            AuthenticationRequiredError
+            if isinstance(openai_client, LlamaStackAsLibraryClient)
+            else httpx.HTTPStatusError
+        )
+        with pytest.raises(exc_type):
+            openai_client.responses.create(
+                model=model,
+                input=case["input"],
+                tools=tools,
+                stream=False,
+            )
+
+        for tool in tools:
+            if tool["type"] == "mcp":
+                tool["server_url"] = mcp_server_info["server_url"]
+                tool["headers"] = {"Authorization": "Bearer test-token"}
+
+        response = openai_client.responses.create(
+            model=model,
+            input=case["input"],
+            tools=tools,
+            stream=False,
+        )
+        assert len(response.output) >= 3
+
+
 @pytest.mark.parametrize(
     "case",
     responses_test_cases["test_response_custom_tool"]["test_params"]["case"],
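To summarize the assertions in the new test: the happy-path, non-streaming response is expected to contain at least three output items, in order — the tool listing, the MCP call, and the model's final message. Sketched as data below; the field values are inferred from the assertions (and the `output_text` content type is an assumption about the message shape), not captured from a real run:

```python
# Shape only; values are illustrative.
expected_output_shape = [
    {
        "type": "mcp_list_tools",
        "server_label": "localmcp",
        "tools": [{"name": "get_boiling_point"}, {"name": "greet_everyone"}],
    },
    {
        "type": "mcp_call",
        "name": "get_boiling_point",
        "arguments": '{"liquid_name": "polyjuice", "celcius": true}',
        "error": None,
        "output": "-100",
    },
    {
        "type": "message",
        "content": [{"type": "output_text", "text": "The boiling point of polyjuice is -100 degrees Celsius."}],
    },
]
```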
2
uv.lock
generated

@@ -1544,6 +1544,7 @@ unit = [
     { name = "aiohttp" },
     { name = "aiosqlite" },
     { name = "chardet" },
+    { name = "mcp" },
     { name = "openai" },
     { name = "opentelemetry-exporter-otlp-proto-http" },
     { name = "pypdf" },
@@ -1576,6 +1577,7 @@ requires-dist = [
     { name = "llama-stack-client", specifier = ">=0.2.7" },
     { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.7" },
     { name = "mcp", marker = "extra == 'test'" },
+    { name = "mcp", marker = "extra == 'unit'" },
     { name = "myst-parser", marker = "extra == 'docs'" },
     { name = "nbval", marker = "extra == 'dev'" },
     { name = "openai", specifier = ">=1.66" },