add basic integration test

This commit is contained in:
Ashwin Bharambe 2025-05-20 18:20:16 -07:00
parent 6e57929ede
commit e6ddf5dac7
43 changed files with 342 additions and 44 deletions

View file

@ -106,10 +106,10 @@ class DistributionCredentialsImpl(CredentialsAPI, CredentialsStore):
principal = get_principal()
# check that provider_id is registered
run_config = self.deps[Api.inspect].run_config
run_config = self.deps[Api.inspect].config.run_config
# TODO: we should make provider_ids unique across all APIs which is not enforced yet
provider_ids = [p.provider_id for p in run_config.providers.values()]
provider_ids = [p.provider_id for plist in run_config.providers.values() for p in plist]
if provider_id not in provider_ids:
raise ValueError(f"Provider {provider_id} is not registered")

View file

@ -9,6 +9,7 @@ from typing import Any
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.credentials import Credentials
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
@ -61,6 +62,7 @@ class InvalidProviderError(Exception):
def api_protocol_map() -> dict[Api, Any]:
return {
Api.providers: ProvidersAPI,
Api.credentials: Credentials,
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,

View file

@ -41,9 +41,11 @@ from llama_stack.distribution.inspect import DistributionInspectConfig, Distribu
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
logger = get_logger(name=__name__, category="core")
@ -214,8 +216,12 @@ async def instantiate_internal_impls(impls: dict[Api, Any], run_config: StackRun
)
await providers_impl.initialize()
# TODO: make metadata_store and credentials_store non-optional by including it in the templates
credentials_impl = DistributionCredentialsImpl(
DistributionCredentialsConfig(kvstore=run_config.credentials_store),
DistributionCredentialsConfig(
kvstore=run_config.credentials_store
or SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / run_config.image_name / "credentials.db").as_posix())
),
deps=impls,
)
await credentials_impl.initialize()
@ -231,18 +237,26 @@ async def instantiate_internal_impls(impls: dict[Api, Any], run_config: StackRun
async def construct_stack(
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
) -> dict[Api, Any]:
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
kvstore_config = run_config.metadata_store or SqliteKVStoreConfig(
db_path=(DISTRIBS_BASE_DIR / run_config.image_name / "kvstore.db").as_posix()
)
dist_registry, _ = await create_dist_registry(kvstore_config)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
# Add internal implementations after all other providers are resolved
internal_impls = await instantiate_internal_impls(impls, run_config)
impls.update(internal_impls)
# credentials_store = internal_impls[Api.credentials]
# for impl in impls.values():
# # in an ideal world, we would pass the credentials store as a dependency
# if hasattr(impl, "credentials_store"):
# impl.credentials_store = credentials_store
# HACK: this is a hack to work around circular dependency issues. we probably need to
# make resolving internal implementations be part of `resolve_impls` again (as it used to be
# a while ago) so that dependencies can be expressed properly.
for impl in impls.values():
from llama_stack.distribution.routers.routing_tables import CommonRoutingTableImpl
if isinstance(impl, CommonRoutingTableImpl):
for provider_impl in impl.impls_by_provider_id.values():
if hasattr(provider_impl, "credentials_store"):
provider_impl.credentials_store = internal_impls[Api.credentials]
await register_resources(run_config, impls)
return impls

View file

@ -11,10 +11,8 @@ from typing import Protocol
import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
logger = get_logger(__name__, category="core")
@ -189,16 +187,9 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async def create_dist_registry(
metadata_store: KVStoreConfig | None,
image_name: str,
kvstore_config: KVStoreConfig,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata
if metadata_store:
dist_kvstore = await kvstore_impl(metadata_store)
else:
dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix())
)
dist_kvstore = await kvstore_impl(kvstore_config)
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
await dist_registry.initialize()
return dist_registry, dist_kvstore

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Protocol
from typing import Any, Protocol, runtime_checkable
from urllib.parse import urlparse
from pydantic import BaseModel, Field
@ -112,6 +112,7 @@ class ProviderSpec(BaseModel):
return self.provider_type in ("sample", "remote::sample")
@runtime_checkable
class RoutingTable(Protocol):
def get_provider_impl(self, routing_key: str) -> Any: ...

View file

@ -83,6 +83,5 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
pip_packages=["mcp"],
),
api_dependencies=[Api.credentials],
),
]

View file

@ -4,16 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from typing import Any, cast
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urlparse
import exceptiongroup
import httpx
from mcp import ClientSession
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem
from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import (
ListToolDefsResponse,
@ -23,12 +24,16 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.distribution.credentials import AuthenticationRequiredError, CredentialsStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import ModelContextProtocolConfig
logger = get_logger(__name__, category="tools")
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, None]:
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
@ -48,9 +53,13 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGen
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: ModelContextProtocolConfig, deps: dict[Api, Any]):
# HACK: this is filled in by the Stack resolver magically right now to work around
# circular dependency issues.
credentials_store: CredentialsStore
def __init__(self, config: ModelContextProtocolConfig, _deps: dict[Api, Any]):
self.config = config
self.credentials_store = cast(CredentialsStore, deps[Api.credentials])
self.credentials_store = None
async def initialize(self):
pass
@ -99,14 +108,27 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async with sse_client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool.identifier, kwargs)
content = []
for item in result.content:
if isinstance(item, mcp_types.TextContent):
content.append(TextContentItem(text=item.text))
elif isinstance(item, mcp_types.ImageContent):
content.append(ImageContentItem(image=item.data))
elif isinstance(item, mcp_types.EmbeddedResource):
logger.warning(f"EmbeddedResource is not supported: {item}")
else:
raise ValueError(f"Unknown content type: {type(item)}")
return ToolInvocationResult(
content=[result.model_dump_json() for result in result.content],
content=content,
error_code=1 if result.isError else 0,
)
async def get_headers(self) -> dict[str, str]:
if self.credentials_store is None:
raise ValueError("credentials_store is not set")
headers = {}
credentials = await self.credentials_store.get_credential(self.__provider_id__)
if credentials:
headers["Authorization"] = f"Bearer {credentials.token}"
token = await self.credentials_store.read_decrypted_credential(self.__provider_id__)
if token:
headers["Authorization"] = f"Bearer {token}"
return headers

View file

@ -96,6 +96,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/credentials.db
models:
- metadata: {}
model_id: meta.llama3-1-8b-instruct-v1:0

View file

@ -99,6 +99,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/credentials.db
models:
- metadata: {}
model_id: llama3.1-8b

View file

@ -99,6 +99,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/credentials.db
models:
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct

View file

@ -99,6 +99,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -95,6 +95,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -111,6 +111,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/credentials.db
models:
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct

View file

@ -106,6 +106,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/credentials.db
models:
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct

View file

@ -99,6 +99,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/credentials.db
models:
- metadata: {}
model_id: groq/llama3-8b-8192

View file

@ -107,6 +107,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -102,6 +102,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -107,6 +107,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -102,6 +102,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -111,6 +111,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/credentials.db
models:
- metadata: {}
model_id: Llama-3.3-70B-Instruct

View file

@ -117,6 +117,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -107,6 +107,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -92,6 +92,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -80,6 +80,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/credentials.db
models:
- metadata: {}
model_id: meta/llama3-8b-instruct

View file

@ -112,6 +112,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -110,6 +110,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -125,6 +125,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/credentials.db
models:
- metadata: {}
model_id: openai/gpt-4o

View file

@ -111,6 +111,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/credentials.db
models:
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct

View file

@ -106,6 +106,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/credentials.db
models:
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct

View file

@ -115,6 +115,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -108,6 +108,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -81,6 +81,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/credentials.db
models:
- metadata: {}
model_id: sambanova/Meta-Llama-3.1-8B-Instruct

View file

@ -133,6 +133,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/credentials.db
models:
- metadata: {}
model_id: openai/gpt-4o

View file

@ -117,6 +117,10 @@ class RunConfigSettings(BaseModel):
__distro_dir__=f"~/.llama/distributions/{name}",
db_name="registry.db",
),
credentials_store=SqliteKVStoreConfig.sample_run_config(
__distro_dir__=f"~/.llama/distributions/{name}",
db_name="credentials.db",
),
models=self.default_models or [],
shields=self.default_shields or [],
tool_groups=self.default_tool_groups or [],

View file

@ -102,6 +102,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -101,6 +101,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -111,6 +111,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/credentials.db
models:
- metadata: {}
model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo

View file

@ -106,6 +106,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/credentials.db
models:
- metadata: {}
model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo

View file

@ -135,6 +135,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/credentials.db
models:
- metadata: {}
model_id: openai/gpt-4o

View file

@ -106,6 +106,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/credentials.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}

View file

@ -103,6 +103,9 @@ providers:
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db
credentials_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/credentials.db
models:
- metadata: {}
model_id: meta-llama/llama-3-3-70b-instruct

View file

@ -0,0 +1,178 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import socket
import threading
import time
import httpx
import mcp.types as types
import pytest
import uvicorn
from llama_stack_client import Agent
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.responses import Response
from starlette.routing import Mount, Route
AUTH_TOKEN = "test-token"
@pytest.fixture(scope="module")
def mcp_server():
server = FastMCP("FastMCP Test Server")
@server.tool()
async def greet_everyone(
url: str, ctx: Context
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
return [types.TextContent(type="text", text="Hello, world!")]
sse = SseServerTransport("/messages/")
async def handle_sse(request):
auth_header = request.headers.get("Authorization")
auth_token = None
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.split(" ")[1]
if auth_token != AUTH_TOKEN:
raise HTTPException(status_code=401, detail="Unauthorized")
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
await server._mcp_server.run(
streams[0],
streams[1],
server._mcp_server.create_initialization_options(),
)
return Response()
app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
)
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
port = get_open_port()
config = uvicorn.Config(app, host="0.0.0.0", port=port)
server_instance = uvicorn.Server(config)
app.state.uvicorn_server = server_instance
def run_server():
server_instance.run()
# Start the server in a new thread
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Polling until the server is ready
timeout = 10
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = httpx.get(f"http://localhost:{port}/sse")
if response.status_code == 401:
break
except httpx.RequestError:
pass
time.sleep(0.1)
yield port
# Tell server to exit
server_instance.should_exit = True
server_thread.join(timeout=5)
def test_mcp_invocation(llama_stack_client, mcp_server):
port = mcp_server
test_toolgroup_id = "remote::mcptest"
# registering itself should fail since it requires listing tools
with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
)
llama_stack_client.credentials.create(
provider_id="model-context-protocol",
token_type="access_token",
token=AUTH_TOKEN,
ttl_seconds=100,
)
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
)
response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
assert len(response) == 1
assert response[0].identifier == "greet_everyone"
assert response[0].type == "tool"
assert len(response[0].parameters) == 1
p = response[0].parameters[0]
assert p.name == "url"
assert p.parameter_type == "string"
assert p.required
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name=response[0].identifier,
kwargs=dict(url="https://www.google.com"),
)
content = response.content
assert len(content) == 1
assert content[0].type == "text"
assert content[0].text == "Hello, world!"
models = llama_stack_client.models.list()
model_id = models[0].identifier
print(f"Using model: {model_id}")
agent = Agent(
client=llama_stack_client,
model=model_id,
instructions="You are a helpful assistant.",
tools=[test_toolgroup_id],
)
session_id = agent.create_session("test-session")
response = agent.create_turn(
session_id=session_id,
messages=[
{
"role": "user",
"content": "Yo. Use tools.",
}
],
stream=False,
)
steps = response.steps
first = steps[0]
assert first.step_type == "inference"
assert len(first.api_model_response.tool_calls) == 1
tool_call = first.api_model_response.tool_calls[0]
assert tool_call.tool_name == "greet_everyone"
second = steps[1]
assert second.step_type == "tool_execution"
tool_response_content = second.tool_responses[0].content
assert len(tool_response_content) == 1
assert tool_response_content[0].type == "text"
assert tool_response_content[0].text == "Hello, world!"
third = steps[2]
assert third.step_type == "inference"
assert len(third.api_model_response.tool_calls) == 0

View file

@ -1,12 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
def test_toolsgroups_unregister(llama_stack_client):
client = llama_stack_client
client.toolgroups.unregister(
toolgroup_id="builtin::websearch",
)