add basic integration test

This commit is contained in:
Ashwin Bharambe 2025-05-20 18:20:16 -07:00
parent 6e57929ede
commit e6ddf5dac7
43 changed files with 342 additions and 44 deletions

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Protocol
from typing import Any, Protocol, runtime_checkable
from urllib.parse import urlparse
from pydantic import BaseModel, Field
@ -112,6 +112,7 @@ class ProviderSpec(BaseModel):
return self.provider_type in ("sample", "remote::sample")
@runtime_checkable
class RoutingTable(Protocol):
def get_provider_impl(self, routing_key: str) -> Any: ...

View file

@ -83,6 +83,5 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
pip_packages=["mcp"],
),
api_dependencies=[Api.credentials],
),
]

View file

@ -4,16 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from typing import Any, cast
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urlparse
import exceptiongroup
import httpx
from mcp import ClientSession
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem
from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import (
ListToolDefsResponse,
@ -23,12 +24,16 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.distribution.credentials import AuthenticationRequiredError, CredentialsStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import ModelContextProtocolConfig
logger = get_logger(__name__, category="tools")
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, None]:
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
@ -48,9 +53,13 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGen
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: ModelContextProtocolConfig, deps: dict[Api, Any]):
# HACK: this is filled in by the Stack resolver magically right now to work around
# circular dependency issues.
credentials_store: CredentialsStore
def __init__(self, config: ModelContextProtocolConfig, _deps: dict[Api, Any]):
self.config = config
self.credentials_store = cast(CredentialsStore, deps[Api.credentials])
self.credentials_store = None
async def initialize(self):
pass
@ -99,14 +108,27 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async with sse_client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool.identifier, kwargs)
content = []
for item in result.content:
if isinstance(item, mcp_types.TextContent):
content.append(TextContentItem(text=item.text))
elif isinstance(item, mcp_types.ImageContent):
content.append(ImageContentItem(image=item.data))
elif isinstance(item, mcp_types.EmbeddedResource):
logger.warning(f"EmbeddedResource is not supported: {item}")
else:
raise ValueError(f"Unknown content type: {type(item)}")
return ToolInvocationResult(
content=[result.model_dump_json() for result in result.content],
content=content,
error_code=1 if result.isError else 0,
)
async def get_headers(self) -> dict[str, str]:
if self.credentials_store is None:
raise ValueError("credentials_store is not set")
headers = {}
credentials = await self.credentials_store.get_credential(self.__provider_id__)
if credentials:
headers["Authorization"] = f"Bearer {credentials.token}"
token = await self.credentials_store.read_decrypted_credential(self.__provider_id__)
if token:
headers["Authorization"] = f"Bearer {token}"
return headers