From 6e57929edea7682ed69eddaaadbdc826ebdb6808 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 19 May 2025 16:51:03 -0700 Subject: [PATCH] feat: basic implementation and usage of Credentials API for MCP --- llama_stack/distribution/credentials.py | 161 ++++++++++++++++++ llama_stack/distribution/datatypes.py | 8 + llama_stack/distribution/request_headers.py | 19 ++- llama_stack/distribution/server/server.py | 6 +- llama_stack/distribution/stack.py | 34 ++-- .../providers/registry/tool_runtime.py | 1 + .../model_context_protocol/__init__.py | 8 +- .../model_context_protocol.py | 88 ++++++---- 8 files changed, 280 insertions(+), 45 deletions(-) create mode 100644 llama_stack/distribution/credentials.py diff --git a/llama_stack/distribution/credentials.py b/llama_stack/distribution/credentials.py new file mode 100644 index 000000000..5d117bebe --- /dev/null +++ b/llama_stack/distribution/credentials.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import uuid +from datetime import datetime, timedelta, timezone +from typing import Any, Protocol + +from pydantic import BaseModel, Field + +from llama_stack.apis.credentials import CredentialListItem, CredentialTokenType +from llama_stack.apis.credentials import Credentials as CredentialsAPI +from llama_stack.distribution.request_headers import get_logged_in_user +from llama_stack.log import get_logger +from llama_stack.providers.utils.kvstore import KVStore, KVStoreConfig, kvstore_impl + +from .datatypes import Api + +logger = get_logger(__name__, category="core") + + +class AuthenticationRequiredError(Exception): + pass + + +class ProviderCredential(BaseModel): + credential_id: str + provider_id: str + token_type: CredentialTokenType + access_token: str + expires_at: datetime = Field(description="The time at which the credential expires. In UTC.") + refresh_token: str | None = None + + +class CredentialsStore(Protocol): + """This is a private protocol used by the distribution and providers to operate on credentials.""" + + async def read_decrypted_credential(self, provider_id: str) -> str | None: ... + + +class DistributionCredentialsConfig(BaseModel): + # TODO: a kvstore isn't the right primitive because we need to look up + # by both `credential_id` (for delete) and (user, provider_id) for fast look ups + kvstore: KVStoreConfig + + +def get_principal() -> str: + logged_in_user = get_logged_in_user() + if not logged_in_user: + # unauth stack, all users have access to this credential + principal = "*" + else: + principal = logged_in_user + return principal + + +class DistributionCredentialsImpl(CredentialsAPI, CredentialsStore): + def __init__(self, config: DistributionCredentialsConfig, deps: dict[Api, Any]): + self.config = config + self.deps = deps + self.store: KVStore | None = None + + async def initialize(self) -> None: + self.store = await kvstore_impl(self.config.kvstore) + + async def shutdown(self) -> None: + pass + + async def get_credentials(self) -> list[CredentialListItem]: + principal = get_principal() + assert self.store is not None + + credentials = [] + start = f"principal:{principal}/" + end = f"principal:{principal}/\xff" + for value in await self.store.values_in_range(start, end): + if not value: + continue + credential = ProviderCredential(**json.loads(value)) + credentials.append( + CredentialListItem( + credential_id=credential.credential_id, + provider_id=credential.provider_id, + token_type=credential.token_type, + expires_at=credential.expires_at, + ) + ) + return credentials + + async def create_credential( + self, + provider_id: str, + token_type: CredentialTokenType, + token: str, + nonce: str | None = None, + ttl_seconds: int = 3600, + ) -> str: + if token_type == CredentialTokenType.oauth2_authorization_code: + # TODO: we need to exchange the authorization code for an access token + # and store { access_token, refresh_token, expires_at } + raise NotImplementedError("OAuth2 authorization code is not supported yet") + + principal = get_principal() + + # check that provider_id is registered + run_config = self.deps[Api.inspect].run_config + + # TODO: we should make provider_ids unique across all APIs which is not enforced yet + provider_ids = [p.provider_id for p in run_config.providers.values()] + if provider_id not in provider_ids: + raise ValueError(f"Provider {provider_id} is not registered") + + credential_id = str(uuid.uuid4()) + credential = ProviderCredential( + credential_id=credential_id, + provider_id=provider_id, + token_type=token_type, + access_token=token, + expires_at=datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds), + refresh_token=None, + ) + await self.store_credential(principal, credential) + return credential_id + + async def delete_credential(self, credential_id: str) -> None: + principal = get_principal() + assert self.store is not None + + credentials = await self.get_credentials() + for credential in credentials: + if credential.credential_id == credential_id: + await self.store.delete(make_credential_key(principal, credential.provider_id)) + return + raise ValueError(f"Credential {credential_id} not found") + + async def store_credential(self, principal: str, credential: ProviderCredential) -> None: + # TODO: encrypt + key = make_credential_key(principal, credential.provider_id) + assert self.store is not None + await self.store.set(key, credential.model_dump_json()) + + async def read_decrypted_credential(self, provider_id: str) -> str | None: + principal = get_principal() + + key = make_credential_key(principal, provider_id) + assert self.store is not None + value = await self.store.get(key) + if not value: + return None + credential = ProviderCredential(**json.loads(value)) + if credential.expires_at < datetime.now(timezone.utc): + logger.info(f"Credential {credential.credential_id} for {provider_id} has expired") + return None + return credential.access_token + + +def make_credential_key(principal: str, provider_id: str) -> str: + return f"principal:{principal}/provider:{provider_id}" diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 446a88ca0..185b2f380 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -296,6 +296,14 @@ can be instantiated multiple times (with different configs) if necessary. Configuration for the persistence store used by the distribution registry. If not specified, a default SQLite store will be used.""", ) + credentials_store: KVStoreConfig | None = Field( + default=None, + description=""" +Configuration for the persistence store used for store ephemeral per-(user, provider) credentials. This +store is different since it may have different security properties (e.g. not encrypted) and may have different +reliability requirements (e.g. not durable). +""", + ) # registry of "resources" in the distribution models: list[ModelInput] = Field(default_factory=list) diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index b03d2dee8..e138205a2 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -22,11 +22,16 @@ class RequestProviderDataContext(AbstractContextManager): """Context manager for request provider data""" def __init__( - self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None + self, + provider_data: dict[str, Any] | None = None, + auth_attributes: dict[str, list[str]] | None = None, + logged_in_user: str | None = None, ): self.provider_data = provider_data or {} if auth_attributes: self.provider_data["__auth_attributes"] = auth_attributes + if logged_in_user: + self.provider_data["__logged_in_user"] = logged_in_user self.token = None @@ -88,11 +93,13 @@ def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | Non def request_provider_data_context( - headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None + headers: dict[str, str], + auth_attributes: dict[str, list[str]] | None = None, + logged_in_user: str | None = None, ) -> AbstractContextManager: """Context manager that sets request provider data from headers and auth attributes for the duration of the context""" provider_data = parse_request_provider_data(headers) - return RequestProviderDataContext(provider_data, auth_attributes) + return RequestProviderDataContext(provider_data, auth_attributes, logged_in_user) def get_auth_attributes() -> dict[str, list[str]] | None: @@ -101,3 +108,9 @@ def get_auth_attributes() -> dict[str, list[str]] | None: if not provider_data: return None return provider_data.get("__auth_attributes") + + +def get_logged_in_user() -> str | None: + """Helper to retrieve logged in user from the provider data context""" + provider_data = PROVIDER_DATA_VAR.get() + return provider_data.get("__logged_in_user") diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index e25bf0817..bbf8cb456 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -209,11 +209,15 @@ def create_dynamic_typed_route(func: Any, method: str, route: str): async def endpoint(request: Request, **kwargs): # Get auth attributes from the request scope user_attributes = request.scope.get("user_attributes", {}) + principal = request.scope.get("principal", None) await log_request_pre_validation(request) + # TODO: before request execution starts, we need to check for authorization + # so we can send back 40X errors before StreamingResponse starts (and returns a 200). + # Use context manager with both provider data and auth attributes - with request_provider_data_context(request.headers, user_attributes): + with request_provider_data_context(request.headers, user_attributes, principal): is_streaming = is_streaming_request(func.__name__, request, **kwargs) try: diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index df980b21c..c4168e2b1 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -34,6 +34,7 @@ from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime from llama_stack.apis.vector_dbs import VectorDBs from llama_stack.apis.vector_io import VectorIO +from llama_stack.distribution.credentials import DistributionCredentialsConfig, DistributionCredentialsImpl from llama_stack.distribution.datatypes import Provider, StackRunConfig from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl @@ -199,24 +200,30 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]: ) from e -def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None: - """Add internal implementations (inspect and providers) to the implementations dictionary. - - Args: - impls: Dictionary of API implementations - run_config: Stack run configuration - """ +async def instantiate_internal_impls(impls: dict[Api, Any], run_config: StackRunConfig) -> dict[Api, Any]: + """Add internal implementations (inspect, providers, credentials).""" inspect_impl = DistributionInspectImpl( DistributionInspectConfig(run_config=run_config), deps=impls, ) - impls[Api.inspect] = inspect_impl + await inspect_impl.initialize() providers_impl = ProviderImpl( ProviderImplConfig(run_config=run_config), deps=impls, ) - impls[Api.providers] = providers_impl + await providers_impl.initialize() + + credentials_impl = DistributionCredentialsImpl( + DistributionCredentialsConfig(kvstore=run_config.credentials_store), + deps=impls, + ) + await credentials_impl.initialize() + return { + Api.inspect: inspect_impl, + Api.providers: providers_impl, + Api.credentials: credentials_impl, + } # Produces a stack of providers for the given run config. Not all APIs may be @@ -228,7 +235,14 @@ async def construct_stack( impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry) # Add internal implementations after all other providers are resolved - add_internal_implementations(impls, run_config) + internal_impls = await instantiate_internal_impls(impls, run_config) + impls.update(internal_impls) + + # credentials_store = internal_impls[Api.credentials] + # for impl in impls.values(): + # # in an ideal world, we would pass the credentials store as a dependency + # if hasattr(impl, "credentials_store"): + # impl.credentials_store = credentials_store await register_resources(run_config, impls) return impls diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index b9194810e..2e789089b 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -83,5 +83,6 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig", pip_packages=["mcp"], ), + api_dependencies=[Api.credentials], ), ] diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py index fb1f558e5..a263f0e9c 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py @@ -4,8 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Any + from pydantic import BaseModel +from llama_stack.apis.datatypes import Api + from .config import ModelContextProtocolConfig @@ -13,9 +17,9 @@ class ModelContextProtocolToolProviderDataValidator(BaseModel): api_key: str -async def get_adapter_impl(config: ModelContextProtocolConfig, _deps): +async def get_adapter_impl(config: ModelContextProtocolConfig, deps: dict[Api, Any]): from .model_context_protocol import ModelContextProtocolToolRuntimeImpl - impl = ModelContextProtocolToolRuntimeImpl(config) + impl = ModelContextProtocolToolRuntimeImpl(config, deps) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 142730e89..8ac3769d5 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -4,13 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any +from collections.abc import AsyncGenerator +from typing import Any, cast from urllib.parse import urlparse +import exceptiongroup +import httpx from mcp import ClientSession from mcp.client.sse import sse_client from llama_stack.apis.common.content_types import URL +from llama_stack.apis.datatypes import Api from llama_stack.apis.tools import ( ListToolDefsResponse, ToolDef, @@ -18,14 +22,35 @@ from llama_stack.apis.tools import ( ToolParameter, ToolRuntime, ) +from llama_stack.distribution.credentials import AuthenticationRequiredError, CredentialsStore from llama_stack.providers.datatypes import ToolsProtocolPrivate from .config import ModelContextProtocolConfig +async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, None]: + try: + async with sse_client(endpoint, headers=headers) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + yield session + except BaseException as e: + # TODO: auto-discover auth metadata and cache it, add a nonce, create state + # which can be used to exchange the authorization code for an access token. + if isinstance(e, exceptiongroup.BaseExceptionGroup): + for exc in e.exceptions: + if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401: + raise AuthenticationRequiredError(exc) from exc + elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401: + raise AuthenticationRequiredError(e) from e + else: + raise + + class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): - def __init__(self, config: ModelContextProtocolConfig): + def __init__(self, config: ModelContextProtocolConfig, deps: dict[Api, Any]): self.config = config + self.credentials_store = cast(CredentialsStore, deps[Api.credentials]) async def initialize(self): pass @@ -36,31 +61,30 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): if mcp_endpoint is None: raise ValueError("mcp_endpoint is required") + headers = await self.get_headers() tools = [] - async with sse_client(mcp_endpoint.uri) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - tools_result = await session.list_tools() - for tool in tools_result.tools: - parameters = [] - for param_name, param_schema in tool.inputSchema.get("properties", {}).items(): - parameters.append( - ToolParameter( - name=param_name, - parameter_type=param_schema.get("type", "string"), - description=param_schema.get("description", ""), - ) - ) - tools.append( - ToolDef( - name=tool.name, - description=tool.description, - parameters=parameters, - metadata={ - "endpoint": mcp_endpoint.uri, - }, + async with sse_client_wrapper(mcp_endpoint.uri, headers) as session: + tools_result = await session.list_tools() + for tool in tools_result.tools: + parameters = [] + for param_name, param_schema in tool.inputSchema.get("properties", {}).items(): + parameters.append( + ToolParameter( + name=param_name, + parameter_type=param_schema.get("type", "string"), + description=param_schema.get("description", ""), ) ) + tools.append( + ToolDef( + name=tool.name, + description=tool.description, + parameters=parameters, + metadata={ + "endpoint": mcp_endpoint.uri, + }, + ) + ) return ListToolDefsResponse(data=tools) async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: @@ -71,12 +95,18 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): if urlparse(endpoint).scheme not in ("http", "https"): raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") - async with sse_client(endpoint) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - result = await session.call_tool(tool.identifier, kwargs) + headers = await self.get_headers() + async with sse_client_wrapper(endpoint, headers) as session: + result = await session.call_tool(tool.identifier, kwargs) return ToolInvocationResult( - content="\n".join([result.model_dump_json() for result in result.content]), + content=[result.model_dump_json() for result in result.content], error_code=1 if result.isError else 0, ) + + async def get_headers(self) -> dict[str, str]: + headers = {} + credentials = await self.credentials_store.get_credential(self.__provider_id__) + if credentials: + headers["Authorization"] = f"Bearer {credentials.token}" + return headers