From e6ddf5dac7e1cc6c8302697a57a5ae8461ccf849 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 20 May 2025 18:20:16 -0700 Subject: [PATCH] add basic integration test --- llama_stack/distribution/credentials.py | 4 +- llama_stack/distribution/resolver.py | 2 + llama_stack/distribution/stack.py | 28 ++- llama_stack/distribution/store/registry.py | 13 +- llama_stack/providers/datatypes.py | 3 +- .../providers/registry/tool_runtime.py | 1 - .../model_context_protocol.py | 42 ++++- llama_stack/templates/bedrock/run.yaml | 3 + llama_stack/templates/cerebras/run.yaml | 3 + llama_stack/templates/ci-tests/run.yaml | 3 + .../templates/dell/run-with-safety.yaml | 3 + llama_stack/templates/dell/run.yaml | 3 + .../templates/fireworks/run-with-safety.yaml | 3 + llama_stack/templates/fireworks/run.yaml | 3 + llama_stack/templates/groq/run.yaml | 3 + .../hf-endpoint/run-with-safety.yaml | 3 + llama_stack/templates/hf-endpoint/run.yaml | 3 + .../hf-serverless/run-with-safety.yaml | 3 + llama_stack/templates/hf-serverless/run.yaml | 3 + llama_stack/templates/llama_api/run.yaml | 3 + .../meta-reference-gpu/run-with-safety.yaml | 3 + .../templates/meta-reference-gpu/run.yaml | 3 + .../templates/nvidia/run-with-safety.yaml | 3 + llama_stack/templates/nvidia/run.yaml | 3 + .../templates/ollama/run-with-safety.yaml | 3 + llama_stack/templates/ollama/run.yaml | 3 + llama_stack/templates/open-benchmark/run.yaml | 3 + .../passthrough/run-with-safety.yaml | 3 + llama_stack/templates/passthrough/run.yaml | 3 + .../remote-vllm/run-with-safety.yaml | 3 + llama_stack/templates/remote-vllm/run.yaml | 3 + llama_stack/templates/sambanova/run.yaml | 3 + llama_stack/templates/starter/run.yaml | 3 + llama_stack/templates/template.py | 4 + .../templates/tgi/run-with-safety.yaml | 3 + llama_stack/templates/tgi/run.yaml | 3 + .../templates/together/run-with-safety.yaml | 3 + llama_stack/templates/together/run.yaml | 3 + llama_stack/templates/verification/run.yaml | 3 + llama_stack/templates/vllm-gpu/run.yaml | 3 + llama_stack/templates/watsonx/run.yaml | 3 + tests/integration/tool_runtime/test_mcp.py | 178 ++++++++++++++++++ tests/integration/tools/test_tools.py | 12 -- 43 files changed, 342 insertions(+), 44 deletions(-) create mode 100644 tests/integration/tool_runtime/test_mcp.py delete mode 100644 tests/integration/tools/test_tools.py diff --git a/llama_stack/distribution/credentials.py b/llama_stack/distribution/credentials.py index 5d117bebe..e04e30c9a 100644 --- a/llama_stack/distribution/credentials.py +++ b/llama_stack/distribution/credentials.py @@ -106,10 +106,10 @@ class DistributionCredentialsImpl(CredentialsAPI, CredentialsStore): principal = get_principal() # check that provider_id is registered - run_config = self.deps[Api.inspect].run_config + run_config = self.deps[Api.inspect].config.run_config # TODO: we should make provider_ids unique across all APIs which is not enforced yet - provider_ids = [p.provider_id for p in run_config.providers.values()] + provider_ids = [p.provider_id for plist in run_config.providers.values() for p in plist] if provider_id not in provider_ids: raise ValueError(f"Provider {provider_id} is not registered") diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 257c495c3..deae01e63 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -9,6 +9,7 @@ from typing import Any from llama_stack.apis.agents import Agents from llama_stack.apis.benchmarks import Benchmarks +from llama_stack.apis.credentials import Credentials from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.eval import Eval @@ -61,6 +62,7 @@ class InvalidProviderError(Exception): def api_protocol_map() -> dict[Api, Any]: return { Api.providers: ProvidersAPI, + Api.credentials: Credentials, Api.agents: Agents, Api.inference: Inference, Api.inspect: Inspect, diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index c4168e2b1..75c4fecaf 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -41,9 +41,11 @@ from llama_stack.distribution.inspect import DistributionInspectConfig, Distribu from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls from llama_stack.distribution.store.registry import create_dist_registry +from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig logger = get_logger(name=__name__, category="core") @@ -214,8 +216,12 @@ async def instantiate_internal_impls(impls: dict[Api, Any], run_config: StackRun ) await providers_impl.initialize() + # TODO: make metadata_store and credentials_store non-optional by including it in the templates credentials_impl = DistributionCredentialsImpl( - DistributionCredentialsConfig(kvstore=run_config.credentials_store), + DistributionCredentialsConfig( + kvstore=run_config.credentials_store + or SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / run_config.image_name / "credentials.db").as_posix()) + ), deps=impls, ) await credentials_impl.initialize() @@ -231,18 +237,26 @@ async def instantiate_internal_impls(impls: dict[Api, Any], run_config: StackRun async def construct_stack( run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None ) -> dict[Api, Any]: - dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) + kvstore_config = run_config.metadata_store or SqliteKVStoreConfig( + db_path=(DISTRIBS_BASE_DIR / run_config.image_name / "kvstore.db").as_posix() + ) + dist_registry, _ = await create_dist_registry(kvstore_config) impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry) # Add internal implementations after all other providers are resolved internal_impls = await instantiate_internal_impls(impls, run_config) impls.update(internal_impls) - # credentials_store = internal_impls[Api.credentials] - # for impl in impls.values(): - # # in an ideal world, we would pass the credentials store as a dependency - # if hasattr(impl, "credentials_store"): - # impl.credentials_store = credentials_store + # HACK: this is a hack to work around circular dependency issues. we probably need to + # make resolving internal implementations be part of `resolve_impls` again (as it used to be + # a while ago) so that dependencies can be expressed properly. + for impl in impls.values(): + from llama_stack.distribution.routers.routing_tables import CommonRoutingTableImpl + + if isinstance(impl, CommonRoutingTableImpl): + for provider_impl in impl.impls_by_provider_id.values(): + if hasattr(provider_impl, "credentials_store"): + provider_impl.credentials_store = internal_impls[Api.credentials] await register_resources(run_config, impls) return impls diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index a6b400136..9205077a7 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -11,10 +11,8 @@ from typing import Protocol import pydantic from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider -from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl -from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig logger = get_logger(__name__, category="core") @@ -189,16 +187,9 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async def create_dist_registry( - metadata_store: KVStoreConfig | None, - image_name: str, + kvstore_config: KVStoreConfig, ) -> tuple[CachedDiskDistributionRegistry, KVStore]: - # instantiate kvstore for storing and retrieving distribution metadata - if metadata_store: - dist_kvstore = await kvstore_impl(metadata_store) - else: - dist_kvstore = await kvstore_impl( - SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()) - ) + dist_kvstore = await kvstore_impl(kvstore_config) dist_registry = CachedDiskDistributionRegistry(dist_kvstore) await dist_registry.initialize() return dist_registry, dist_kvstore diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 3e9806f23..2050f34bc 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Any, Protocol +from typing import Any, Protocol, runtime_checkable from urllib.parse import urlparse from pydantic import BaseModel, Field @@ -112,6 +112,7 @@ class ProviderSpec(BaseModel): return self.provider_type in ("sample", "remote::sample") +@runtime_checkable class RoutingTable(Protocol): def get_provider_impl(self, routing_key: str) -> Any: ... diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 2e789089b..b9194810e 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -83,6 +83,5 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig", pip_packages=["mcp"], ), - api_dependencies=[Api.credentials], ), ] diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 8ac3769d5..b6e3c0622 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -4,16 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from collections.abc import AsyncGenerator -from typing import Any, cast +from contextlib import asynccontextmanager +from typing import Any from urllib.parse import urlparse import exceptiongroup import httpx from mcp import ClientSession +from mcp import types as mcp_types from mcp.client.sse import sse_client -from llama_stack.apis.common.content_types import URL +from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem from llama_stack.apis.datatypes import Api from llama_stack.apis.tools import ( ListToolDefsResponse, @@ -23,12 +24,16 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.distribution.credentials import AuthenticationRequiredError, CredentialsStore +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolsProtocolPrivate from .config import ModelContextProtocolConfig +logger = get_logger(__name__, category="tools") -async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, None]: + +@asynccontextmanager +async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): try: async with sse_client(endpoint, headers=headers) as streams: async with ClientSession(*streams) as session: @@ -48,9 +53,13 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGen class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): - def __init__(self, config: ModelContextProtocolConfig, deps: dict[Api, Any]): + # HACK: this is filled in by the Stack resolver magically right now to work around + # circular dependency issues. + credentials_store: CredentialsStore + + def __init__(self, config: ModelContextProtocolConfig, _deps: dict[Api, Any]): self.config = config - self.credentials_store = cast(CredentialsStore, deps[Api.credentials]) + self.credentials_store = None async def initialize(self): pass @@ -99,14 +108,27 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): async with sse_client_wrapper(endpoint, headers) as session: result = await session.call_tool(tool.identifier, kwargs) + content = [] + for item in result.content: + if isinstance(item, mcp_types.TextContent): + content.append(TextContentItem(text=item.text)) + elif isinstance(item, mcp_types.ImageContent): + content.append(ImageContentItem(image=item.data)) + elif isinstance(item, mcp_types.EmbeddedResource): + logger.warning(f"EmbeddedResource is not supported: {item}") + else: + raise ValueError(f"Unknown content type: {type(item)}") return ToolInvocationResult( - content=[result.model_dump_json() for result in result.content], + content=content, error_code=1 if result.isError else 0, ) async def get_headers(self) -> dict[str, str]: + if self.credentials_store is None: + raise ValueError("credentials_store is not set") + headers = {} - credentials = await self.credentials_store.get_credential(self.__provider_id__) - if credentials: - headers["Authorization"] = f"Bearer {credentials.token}" + token = await self.credentials_store.read_decrypted_credential(self.__provider_id__) + if token: + headers["Authorization"] = f"Bearer {token}" return headers diff --git a/llama_stack/templates/bedrock/run.yaml b/llama_stack/templates/bedrock/run.yaml index 30599a6c0..2e5a12b8f 100644 --- a/llama_stack/templates/bedrock/run.yaml +++ b/llama_stack/templates/bedrock/run.yaml @@ -96,6 +96,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/credentials.db models: - metadata: {} model_id: meta.llama3-1-8b-instruct-v1:0 diff --git a/llama_stack/templates/cerebras/run.yaml b/llama_stack/templates/cerebras/run.yaml index 0731b1df9..4c10574a3 100644 --- a/llama_stack/templates/cerebras/run.yaml +++ b/llama_stack/templates/cerebras/run.yaml @@ -99,6 +99,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/credentials.db models: - metadata: {} model_id: llama3.1-8b diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml index d9ee5b3cf..1083cb3d6 100644 --- a/llama_stack/templates/ci-tests/run.yaml +++ b/llama_stack/templates/ci-tests/run.yaml @@ -99,6 +99,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/credentials.db models: - metadata: {} model_id: accounts/fireworks/models/llama-v3p1-8b-instruct diff --git a/llama_stack/templates/dell/run-with-safety.yaml b/llama_stack/templates/dell/run-with-safety.yaml index 24c515112..6401e70a6 100644 --- a/llama_stack/templates/dell/run-with-safety.yaml +++ b/llama_stack/templates/dell/run-with-safety.yaml @@ -99,6 +99,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/dell/run.yaml b/llama_stack/templates/dell/run.yaml index fdece894f..eca7f18e3 100644 --- a/llama_stack/templates/dell/run.yaml +++ b/llama_stack/templates/dell/run.yaml @@ -95,6 +95,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/fireworks/run-with-safety.yaml b/llama_stack/templates/fireworks/run-with-safety.yaml index 0ab07613e..bed2eb3a8 100644 --- a/llama_stack/templates/fireworks/run-with-safety.yaml +++ b/llama_stack/templates/fireworks/run-with-safety.yaml @@ -111,6 +111,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/credentials.db models: - metadata: {} model_id: accounts/fireworks/models/llama-v3p1-8b-instruct diff --git a/llama_stack/templates/fireworks/run.yaml b/llama_stack/templates/fireworks/run.yaml index 81c293a46..814c8b7a4 100644 --- a/llama_stack/templates/fireworks/run.yaml +++ b/llama_stack/templates/fireworks/run.yaml @@ -106,6 +106,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/credentials.db models: - metadata: {} model_id: accounts/fireworks/models/llama-v3p1-8b-instruct diff --git a/llama_stack/templates/groq/run.yaml b/llama_stack/templates/groq/run.yaml index 79c350c73..50c91f1d7 100644 --- a/llama_stack/templates/groq/run.yaml +++ b/llama_stack/templates/groq/run.yaml @@ -99,6 +99,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/credentials.db models: - metadata: {} model_id: groq/llama3-8b-8192 diff --git a/llama_stack/templates/hf-endpoint/run-with-safety.yaml b/llama_stack/templates/hf-endpoint/run-with-safety.yaml index 82bcaa3cf..c5df7d5f5 100644 --- a/llama_stack/templates/hf-endpoint/run-with-safety.yaml +++ b/llama_stack/templates/hf-endpoint/run-with-safety.yaml @@ -107,6 +107,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/hf-endpoint/run.yaml b/llama_stack/templates/hf-endpoint/run.yaml index ec7c55032..9cbe6bbfd 100644 --- a/llama_stack/templates/hf-endpoint/run.yaml +++ b/llama_stack/templates/hf-endpoint/run.yaml @@ -102,6 +102,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/hf-serverless/run-with-safety.yaml b/llama_stack/templates/hf-serverless/run-with-safety.yaml index 320976e2c..9e91d393e 100644 --- a/llama_stack/templates/hf-serverless/run-with-safety.yaml +++ b/llama_stack/templates/hf-serverless/run-with-safety.yaml @@ -107,6 +107,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/hf-serverless/run.yaml b/llama_stack/templates/hf-serverless/run.yaml index 2b22b20c6..8a6ae9adb 100644 --- a/llama_stack/templates/hf-serverless/run.yaml +++ b/llama_stack/templates/hf-serverless/run.yaml @@ -102,6 +102,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/llama_api/run.yaml b/llama_stack/templates/llama_api/run.yaml index a879482d7..bb5deb185 100644 --- a/llama_stack/templates/llama_api/run.yaml +++ b/llama_stack/templates/llama_api/run.yaml @@ -111,6 +111,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/credentials.db models: - metadata: {} model_id: Llama-3.3-70B-Instruct diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml index 180d44e0f..1475ea745 100644 --- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml +++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml @@ -117,6 +117,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml index d879667e0..2bd06878f 100644 --- a/llama_stack/templates/meta-reference-gpu/run.yaml +++ b/llama_stack/templates/meta-reference-gpu/run.yaml @@ -107,6 +107,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index 3cdb8e3d2..38b7bda23 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -92,6 +92,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index 3337b7942..cda5fdf51 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -80,6 +80,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/credentials.db models: - metadata: {} model_id: meta/llama3-8b-instruct diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml index 74d0822ca..0abfb66f9 100644 --- a/llama_stack/templates/ollama/run-with-safety.yaml +++ b/llama_stack/templates/ollama/run-with-safety.yaml @@ -112,6 +112,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index 71229be97..00c77f138 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -110,6 +110,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 30a27cbd8..807adcde6 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -125,6 +125,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/credentials.db models: - metadata: {} model_id: openai/gpt-4o diff --git a/llama_stack/templates/passthrough/run-with-safety.yaml b/llama_stack/templates/passthrough/run-with-safety.yaml index a91b9fc92..c6e3d1eee 100644 --- a/llama_stack/templates/passthrough/run-with-safety.yaml +++ b/llama_stack/templates/passthrough/run-with-safety.yaml @@ -111,6 +111,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/credentials.db models: - metadata: {} model_id: meta-llama/Llama-3.1-8B-Instruct diff --git a/llama_stack/templates/passthrough/run.yaml b/llama_stack/templates/passthrough/run.yaml index d1dd3b885..ae591c24a 100644 --- a/llama_stack/templates/passthrough/run.yaml +++ b/llama_stack/templates/passthrough/run.yaml @@ -106,6 +106,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/credentials.db models: - metadata: {} model_id: meta-llama/Llama-3.1-8B-Instruct diff --git a/llama_stack/templates/remote-vllm/run-with-safety.yaml b/llama_stack/templates/remote-vllm/run-with-safety.yaml index 6931d4ba9..b087a3a6b 100644 --- a/llama_stack/templates/remote-vllm/run-with-safety.yaml +++ b/llama_stack/templates/remote-vllm/run-with-safety.yaml @@ -115,6 +115,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml index 05671165d..53c488685 100644 --- a/llama_stack/templates/remote-vllm/run.yaml +++ b/llama_stack/templates/remote-vllm/run.yaml @@ -108,6 +108,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index 620d50307..9e6b8bd78 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -81,6 +81,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/credentials.db models: - metadata: {} model_id: sambanova/Meta-Llama-3.1-8B-Instruct diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index 402695850..02b0f4ae6 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -133,6 +133,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/credentials.db models: - metadata: {} model_id: openai/gpt-4o diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index e4d28d904..6e25aee1d 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -117,6 +117,10 @@ class RunConfigSettings(BaseModel): __distro_dir__=f"~/.llama/distributions/{name}", db_name="registry.db", ), + credentials_store=SqliteKVStoreConfig.sample_run_config( + __distro_dir__=f"~/.llama/distributions/{name}", + db_name="credentials.db", + ), models=self.default_models or [], shields=self.default_shields or [], tool_groups=self.default_tool_groups or [], diff --git a/llama_stack/templates/tgi/run-with-safety.yaml b/llama_stack/templates/tgi/run-with-safety.yaml index 3255e9c0b..15d9bd404 100644 --- a/llama_stack/templates/tgi/run-with-safety.yaml +++ b/llama_stack/templates/tgi/run-with-safety.yaml @@ -102,6 +102,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/tgi/run.yaml b/llama_stack/templates/tgi/run.yaml index 179087258..05a321bcd 100644 --- a/llama_stack/templates/tgi/run.yaml +++ b/llama_stack/templates/tgi/run.yaml @@ -101,6 +101,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/together/run-with-safety.yaml b/llama_stack/templates/together/run-with-safety.yaml index fe8c8e397..4c794cf04 100644 --- a/llama_stack/templates/together/run-with-safety.yaml +++ b/llama_stack/templates/together/run-with-safety.yaml @@ -111,6 +111,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/credentials.db models: - metadata: {} model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml index b903fc659..d140b10d6 100644 --- a/llama_stack/templates/together/run.yaml +++ b/llama_stack/templates/together/run.yaml @@ -106,6 +106,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/credentials.db models: - metadata: {} model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo diff --git a/llama_stack/templates/verification/run.yaml b/llama_stack/templates/verification/run.yaml index 11af41da9..07b19277c 100644 --- a/llama_stack/templates/verification/run.yaml +++ b/llama_stack/templates/verification/run.yaml @@ -135,6 +135,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/credentials.db models: - metadata: {} model_id: openai/gpt-4o diff --git a/llama_stack/templates/vllm-gpu/run.yaml b/llama_stack/templates/vllm-gpu/run.yaml index 5d3482528..b7ff93bb2 100644 --- a/llama_stack/templates/vllm-gpu/run.yaml +++ b/llama_stack/templates/vllm-gpu/run.yaml @@ -106,6 +106,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/credentials.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} diff --git a/llama_stack/templates/watsonx/run.yaml b/llama_stack/templates/watsonx/run.yaml index 8de6a2b6c..8988311c1 100644 --- a/llama_stack/templates/watsonx/run.yaml +++ b/llama_stack/templates/watsonx/run.yaml @@ -103,6 +103,9 @@ providers: metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db +credentials_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/credentials.db models: - metadata: {} model_id: meta-llama/llama-3-3-70b-instruct diff --git a/tests/integration/tool_runtime/test_mcp.py b/tests/integration/tool_runtime/test_mcp.py new file mode 100644 index 000000000..a086316a6 --- /dev/null +++ b/tests/integration/tool_runtime/test_mcp.py @@ -0,0 +1,178 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import socket +import threading +import time + +import httpx +import mcp.types as types +import pytest +import uvicorn +from llama_stack_client import Agent +from mcp.server.fastmcp import Context, FastMCP +from mcp.server.sse import SseServerTransport +from starlette.applications import Starlette +from starlette.exceptions import HTTPException +from starlette.responses import Response +from starlette.routing import Mount, Route + +AUTH_TOKEN = "test-token" + + +@pytest.fixture(scope="module") +def mcp_server(): + server = FastMCP("FastMCP Test Server") + + @server.tool() + async def greet_everyone( + url: str, ctx: Context + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + return [types.TextContent(type="text", text="Hello, world!")] + + sse = SseServerTransport("/messages/") + + async def handle_sse(request): + auth_header = request.headers.get("Authorization") + auth_token = None + if auth_header and auth_header.startswith("Bearer "): + auth_token = auth_header.split(" ")[1] + + if auth_token != AUTH_TOKEN: + raise HTTPException(status_code=401, detail="Unauthorized") + + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: + await server._mcp_server.run( + streams[0], + streams[1], + server._mcp_server.create_initialization_options(), + ) + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ], + ) + + def get_open_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + port = get_open_port() + + config = uvicorn.Config(app, host="0.0.0.0", port=port) + server_instance = uvicorn.Server(config) + app.state.uvicorn_server = server_instance + + def run_server(): + server_instance.run() + + # Start the server in a new thread + server_thread = threading.Thread(target=run_server, daemon=True) + server_thread.start() + + # Polling until the server is ready + timeout = 10 + start_time = time.time() + + while time.time() - start_time < timeout: + try: + response = httpx.get(f"http://localhost:{port}/sse") + if response.status_code == 401: + break + except httpx.RequestError: + pass + time.sleep(0.1) + + yield port + + # Tell server to exit + server_instance.should_exit = True + server_thread.join(timeout=5) + + +def test_mcp_invocation(llama_stack_client, mcp_server): + port = mcp_server + test_toolgroup_id = "remote::mcptest" + + # registering itself should fail since it requires listing tools + with pytest.raises(Exception, match="Unauthorized"): + llama_stack_client.toolgroups.register( + toolgroup_id=test_toolgroup_id, + provider_id="model-context-protocol", + mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"), + ) + + llama_stack_client.credentials.create( + provider_id="model-context-protocol", + token_type="access_token", + token=AUTH_TOKEN, + ttl_seconds=100, + ) + llama_stack_client.toolgroups.register( + toolgroup_id=test_toolgroup_id, + provider_id="model-context-protocol", + mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"), + ) + response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) + assert len(response) == 1 + assert response[0].identifier == "greet_everyone" + assert response[0].type == "tool" + assert len(response[0].parameters) == 1 + p = response[0].parameters[0] + assert p.name == "url" + assert p.parameter_type == "string" + assert p.required + + response = llama_stack_client.tool_runtime.invoke_tool( + tool_name=response[0].identifier, + kwargs=dict(url="https://www.google.com"), + ) + content = response.content + assert len(content) == 1 + assert content[0].type == "text" + assert content[0].text == "Hello, world!" + + models = llama_stack_client.models.list() + model_id = models[0].identifier + print(f"Using model: {model_id}") + agent = Agent( + client=llama_stack_client, + model=model_id, + instructions="You are a helpful assistant.", + tools=[test_toolgroup_id], + ) + session_id = agent.create_session("test-session") + response = agent.create_turn( + session_id=session_id, + messages=[ + { + "role": "user", + "content": "Yo. Use tools.", + } + ], + stream=False, + ) + steps = response.steps + first = steps[0] + assert first.step_type == "inference" + assert len(first.api_model_response.tool_calls) == 1 + tool_call = first.api_model_response.tool_calls[0] + assert tool_call.tool_name == "greet_everyone" + + second = steps[1] + assert second.step_type == "tool_execution" + tool_response_content = second.tool_responses[0].content + assert len(tool_response_content) == 1 + assert tool_response_content[0].type == "text" + assert tool_response_content[0].text == "Hello, world!" + + third = steps[2] + assert third.step_type == "inference" + assert len(third.api_model_response.tool_calls) == 0 diff --git a/tests/integration/tools/test_tools.py b/tests/integration/tools/test_tools.py deleted file mode 100644 index 162669bb4..000000000 --- a/tests/integration/tools/test_tools.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -def test_toolsgroups_unregister(llama_stack_client): - client = llama_stack_client - client.toolgroups.unregister( - toolgroup_id="builtin::websearch", - )