When registering an MCP endpoint, we can no longer list its tools eagerly (as we used to), since the endpoint may sit behind an auth wall and registration can happen much earlier (via run.yaml). Instead, we list tools only when the _user_ actually requests a listing, and we cache the result in-memory on the server. Currently the cache is never invalidated; we may want to periodically re-list MCP servers. Note that clients must call `list_tools` before calling `invoke_tool` -- we rely on this critically. This change enables listing MCP servers in run.yaml.

## Test Plan

Existing tests, updated accordingly.
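For illustration only, here is a sketch of what an MCP toolgroup entry in run.yaml might look like; the field names mirror the `toolgroups.register()` arguments used in the test below, and the toolgroup id and URI are placeholders:

```yaml
# Hypothetical run.yaml fragment -- registers an MCP toolgroup at startup.
# No listing happens here; tools are fetched lazily on the first list_tools call.
tool_groups:
  - toolgroup_id: mcp::my_mcp_server     # placeholder id
    provider_id: model-context-protocol
    mcp_endpoint:
      uri: https://example.com/mcp/sse   # placeholder URI
```

The updated integration test below exercises both the unauthorized and authorized paths: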
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json

import pytest
from llama_stack_client import Agent

from llama_stack import LlamaStackAsLibraryClient
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from tests.common.mcp import MCP_TOOLGROUP_ID, make_mcp_server

AUTH_TOKEN = "test-token"

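# Run a local MCP server that requires a Bearer token, so the test can exercise
# both the unauthorized and authorized paths.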
@pytest.fixture(scope="function")
def mcp_server():
    with make_mcp_server(required_auth_token=AUTH_TOKEN) as mcp_server_info:
        yield mcp_server_info

def test_mcp_invocation(llama_stack_client, mcp_server):
    if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
        pytest.skip("The local MCP server is only reliably reachable from the library client.")

    test_toolgroup_id = MCP_TOOLGROUP_ID
    uri = mcp_server["server_url"]

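    # Registration no longer performs an eager list_tools round-trip -- the MCP
    # endpoint may be behind an auth wall, so listing is deferred to the first
    # tools.list() call and cached in-memory on the server.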
    # registering should not raise an error anymore even if you don't specify the auth token
    llama_stack_client.toolgroups.register(
        toolgroup_id=test_toolgroup_id,
        provider_id="model-context-protocol",
        mcp_endpoint=dict(uri=uri),
    )

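    # Auth is supplied per-request: MCP headers ride along in the
    # X-LlamaStack-Provider-Data header, keyed by server URI, instead of being
    # stored with the toolgroup registration.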
    provider_data = {
        "mcp_headers": {
            uri: [
                f"Authorization: Bearer {AUTH_TOKEN}",
            ],
        },
    }
    auth_headers = {
        "X-LlamaStack-Provider-Data": json.dumps(provider_data),
    }

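    # Without credentials, the lazy listing reaches the MCP server and surfaces
    # its 401.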
    with pytest.raises(Exception, match="Unauthorized"):
        llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)

    response = llama_stack_client.tools.list(
        toolgroup_id=test_toolgroup_id,
        extra_headers=auth_headers,
    )
    assert len(response) == 2
    assert {t.identifier for t in response} == {"greet_everyone", "get_boiling_point"}

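    # invoke_tool depends on the server-side cache that the tools.list() call
    # above just populated; list_tools must always precede invoke_tool.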
    response = llama_stack_client.tool_runtime.invoke_tool(
        tool_name="greet_everyone",
        kwargs=dict(url="https://www.google.com"),
        extra_headers=auth_headers,
    )
    content = response.content
    assert len(content) == 1
    assert content[0].type == "text"
    assert content[0].text == "Hello, world!"

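    # Pick any non-safety LLM served by the stack to drive the agent.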
    models = [
        m for m in llama_stack_client.models.list() if m.model_type == ModelType.llm and "guard" not in m.identifier
    ]
    model_id = models[0].identifier
    print(f"Using model: {model_id}")
    agent = Agent(
        client=llama_stack_client,
        model=model_id,
        instructions="You are a helpful assistant.",
        tools=[test_toolgroup_id],
    )
    session_id = agent.create_session("test-session")

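    # A successful turn should produce three steps: inference (tool call),
    # tool_execution, then a final inference over the tool result.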
    response = agent.create_turn(
        session_id=session_id,
        messages=[
            {
                "role": "user",
                "content": "Say hi to the world. Use tools to do so.",
            }
        ],
        stream=False,
        extra_headers=auth_headers,
    )
    steps = response.steps
    first = steps[0]
    assert first.step_type == "inference"
    assert len(first.api_model_response.tool_calls) == 1
    tool_call = first.api_model_response.tool_calls[0]
    assert tool_call.tool_name == "greet_everyone"

    second = steps[1]
    assert second.step_type == "tool_execution"
    tool_response_content = second.tool_responses[0].content
    assert len(tool_response_content) == 1
    assert tool_response_content[0].type == "text"
    assert tool_response_content[0].text == "Hello, world!"

    third = steps[2]
    assert third.step_type == "inference"

    # When streaming, we currently don't check auth headers upfront and fail the
    # request early, but we should at least generate a 401 later in the process.
    # Note that auth_headers is deliberately omitted from this turn.
    response = agent.create_turn(
        session_id=session_id,
        messages=[
            {
                "role": "user",
                "content": "What is the boiling point of polyjuice? Use tools to answer.",
            }
        ],
        stream=True,
    )
    if isinstance(llama_stack_client, LlamaStackAsLibraryClient):
        with pytest.raises(AuthenticationRequiredError):
            for _ in response:
                pass
    else:
        error_chunks = [chunk for chunk in response if "error" in chunk.model_dump()]
        assert len(error_chunks) == 1
        chunk = error_chunks[0].model_dump()
        assert "Unauthorized" in chunk["error"]["message"]
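Assuming the usual integration-test setup for this repo, something like `pytest -k test_mcp_invocation` against a library-client stack configuration should run the test end to end. The HTTP-client branch at the bottom is kept for completeness, since the test skips non-library clients up front.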