feat: basic implementation and usage of Credentials API for MCP

Ashwin Bharambe 2025-05-19 16:51:03 -07:00
parent b43cdaaed5
commit 6e57929ede
8 changed files with 280 additions and 45 deletions
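
A rough sketch of the intended flow (illustrative only; the `api_key` token type and the
"model-context-protocol" provider_id are assumptions, not shown in this diff):

async def demo(credentials_impl: DistributionCredentialsImpl) -> None:
    # A user stores a short-lived token for the MCP provider; the returned credential_id
    # can later be passed to delete_credential().
    credential_id = await credentials_impl.create_credential(
        provider_id="model-context-protocol",    # assumed provider_id
        token_type=CredentialTokenType.api_key,  # assumed enum member
        token="<token issued by the MCP server>",
        ttl_seconds=3600,
    )
    # The MCP tool runtime later reads it back through the CredentialsStore protocol and
    # sends it as an "Authorization: Bearer <token>" header.
    token = await credentials_impl.read_decrypted_credential("model-context-protocol")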


@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Protocol
from pydantic import BaseModel, Field
from llama_stack.apis.credentials import CredentialListItem, CredentialTokenType
from llama_stack.apis.credentials import Credentials as CredentialsAPI
from llama_stack.distribution.request_headers import get_logged_in_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, KVStoreConfig, kvstore_impl
from .datatypes import Api
logger = get_logger(__name__, category="core")
class AuthenticationRequiredError(Exception):
"""Raised when a provider endpoint requires credentials (e.g. the MCP server returned 401)."""
class ProviderCredential(BaseModel):
credential_id: str
provider_id: str
token_type: CredentialTokenType
access_token: str
expires_at: datetime = Field(description="The time at which the credential expires. In UTC.")
refresh_token: str | None = None
class CredentialsStore(Protocol):
"""This is a private protocol used by the distribution and providers to operate on credentials."""
async def read_decrypted_credential(self, provider_id: str) -> str | None: ...
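# Providers receive the concrete implementation via deps[Api.credentials] instead of importing it
# directly; see ModelContextProtocolToolRuntimeImpl.get_headers() later in this commit.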
class DistributionCredentialsConfig(BaseModel):
# TODO: a kvstore isn't the right primitive here because we need to look up credentials
# both by `credential_id` (for delete) and by (user, provider_id) for fast lookups
kvstore: KVStoreConfig
def get_principal() -> str:
logged_in_user = get_logged_in_user()
if not logged_in_user:
# unauthenticated stack; all users have access to this credential
principal = "*"
else:
principal = logged_in_user
return principal
class DistributionCredentialsImpl(CredentialsAPI, CredentialsStore):
def __init__(self, config: DistributionCredentialsConfig, deps: dict[Api, Any]):
self.config = config
self.deps = deps
self.store: KVStore | None = None
async def initialize(self) -> None:
self.store = await kvstore_impl(self.config.kvstore)
async def shutdown(self) -> None:
pass
async def get_credentials(self) -> list[CredentialListItem]:
principal = get_principal()
assert self.store is not None
credentials = []
start = f"principal:{principal}/"
end = f"principal:{principal}/\xff"
for value in await self.store.values_in_range(start, end):
if not value:
continue
credential = ProviderCredential(**json.loads(value))
credentials.append(
CredentialListItem(
credential_id=credential.credential_id,
provider_id=credential.provider_id,
token_type=credential.token_type,
expires_at=credential.expires_at,
)
)
return credentials
async def create_credential(
self,
provider_id: str,
token_type: CredentialTokenType,
token: str,
nonce: str | None = None,
ttl_seconds: int = 3600,
) -> str:
if token_type == CredentialTokenType.oauth2_authorization_code:
# TODO: we need to exchange the authorization code for an access token
# and store { access_token, refresh_token, expires_at }
raise NotImplementedError("OAuth2 authorization code is not supported yet")
principal = get_principal()
# check that provider_id is registered
run_config = self.deps[Api.inspect].run_config
# TODO: provider_ids should be unique across all APIs, but this is not enforced yet
provider_ids = [p.provider_id for providers in run_config.providers.values() for p in providers]
if provider_id not in provider_ids:
raise ValueError(f"Provider {provider_id} is not registered")
credential_id = str(uuid.uuid4())
credential = ProviderCredential(
credential_id=credential_id,
provider_id=provider_id,
token_type=token_type,
access_token=token,
expires_at=datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds),
refresh_token=None,
)
await self.store_credential(principal, credential)
return credential_id
async def delete_credential(self, credential_id: str) -> None:
principal = get_principal()
assert self.store is not None
credentials = await self.get_credentials()
for credential in credentials:
if credential.credential_id == credential_id:
await self.store.delete(make_credential_key(principal, credential.provider_id))
return
raise ValueError(f"Credential {credential_id} not found")
async def store_credential(self, principal: str, credential: ProviderCredential) -> None:
# TODO: encrypt
key = make_credential_key(principal, credential.provider_id)
assert self.store is not None
await self.store.set(key, credential.model_dump_json())
async def read_decrypted_credential(self, provider_id: str) -> str | None:
principal = get_principal()
key = make_credential_key(principal, provider_id)
assert self.store is not None
value = await self.store.get(key)
if not value:
return None
credential = ProviderCredential(**json.loads(value))
if credential.expires_at < datetime.now(timezone.utc):
logger.info(f"Credential {credential.credential_id} for {provider_id} has expired")
return None
return credential.access_token
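# e.g. make_credential_key("alice", "model-context-protocol") -> "principal:alice/provider:model-context-protocol"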
def make_credential_key(principal: str, provider_id: str) -> str:
return f"principal:{principal}/provider:{provider_id}"


@ -296,6 +296,14 @@ can be instantiated multiple times (with different configs) if necessary.
Configuration for the persistence store used by the distribution registry. If not specified,
a default SQLite store will be used.""",
)
credentials_store: KVStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used to store ephemeral per-(user, provider) credentials. This
store is kept separate because it may have different security properties (e.g. not encrypted) and different
reliability requirements (e.g. not durable).
""",
)
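# Hypothetical run-config snippet (not part of this commit) enabling the store with a SQLite backend:
#   credentials_store:
#     type: sqlite
#     db_path: ~/.llama/distributions/<name>/credentials.db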
# registry of "resources" in the distribution
models: list[ModelInput] = Field(default_factory=list)


@ -22,11 +22,16 @@ class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data"""
def __init__(
self,
provider_data: dict[str, Any] | None = None,
auth_attributes: dict[str, list[str]] | None = None,
logged_in_user: str | None = None,
):
self.provider_data = provider_data or {}
if auth_attributes:
self.provider_data["__auth_attributes"] = auth_attributes
if logged_in_user:
self.provider_data["__logged_in_user"] = logged_in_user
self.token = None
@ -88,11 +93,13 @@ def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | Non
def request_provider_data_context(
headers: dict[str, str],
auth_attributes: dict[str, list[str]] | None = None,
logged_in_user: str | None = None,
) -> AbstractContextManager:
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data, auth_attributes, logged_in_user)
def get_auth_attributes() -> dict[str, list[str]] | None:
@ -101,3 +108,9 @@ def get_auth_attributes() -> dict[str, list[str]] | None:
if not provider_data:
return None
return provider_data.get("__auth_attributes")
def get_logged_in_user() -> str | None:
"""Helper to retrieve the logged-in user from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
return None
return provider_data.get("__logged_in_user")


@ -209,11 +209,15 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
# Get auth attributes and the authenticated principal from the request scope
user_attributes = request.scope.get("user_attributes", {})
principal = request.scope.get("principal", None)
await log_request_pre_validation(request)
# TODO: before request execution starts, we need to check for authorization
# so we can send back 40X errors before StreamingResponse starts (and returns a 200).
# Use context manager with provider data, auth attributes, and the logged-in principal
with request_provider_data_context(request.headers, user_attributes, principal):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:


@ -34,6 +34,7 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.credentials import DistributionCredentialsConfig, DistributionCredentialsImpl
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
@ -199,24 +200,30 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
) from e
async def instantiate_internal_impls(impls: dict[Api, Any], run_config: StackRunConfig) -> dict[Api, Any]:
"""Instantiate internal implementations (inspect, providers, credentials) and return them keyed by API."""
inspect_impl = DistributionInspectImpl(
DistributionInspectConfig(run_config=run_config),
deps=impls,
)
await inspect_impl.initialize()
providers_impl = ProviderImpl(
ProviderImplConfig(run_config=run_config),
deps=impls,
)
await providers_impl.initialize()
credentials_impl = DistributionCredentialsImpl(
DistributionCredentialsConfig(kvstore=run_config.credentials_store),
deps=impls,
)
await credentials_impl.initialize()
return {
Api.inspect: inspect_impl,
Api.providers: providers_impl,
Api.credentials: credentials_impl,
}
# Produces a stack of providers for the given run config. Not all APIs may be
@ -228,7 +235,14 @@ async def construct_stack(
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
# Add internal implementations after all other providers are resolved
internal_impls = await instantiate_internal_impls(impls, run_config)
impls.update(internal_impls)
# credentials_store = internal_impls[Api.credentials]
# for impl in impls.values():
# # in an ideal world, we would pass the credentials store as a dependency
# if hasattr(impl, "credentials_store"):
# impl.credentials_store = credentials_store
await register_resources(run_config, impls)
return impls


@ -83,5 +83,6 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
pip_packages=["mcp"],
),
api_dependencies=[Api.credentials],
),
]
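# With api_dependencies=[Api.credentials], the resolver hands deps[Api.credentials] to the adapter's
# get_adapter_impl() (see the MCP adapter __init__ below), which passes it on to the runtime impl.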


@ -4,8 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.datatypes import Api
from .config import ModelContextProtocolConfig
@ -13,9 +17,9 @@ class ModelContextProtocolToolProviderDataValidator(BaseModel):
api_key: str
async def get_adapter_impl(config: ModelContextProtocolConfig, deps: dict[Api, Any]):
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
impl = ModelContextProtocolToolRuntimeImpl(config, deps)
await impl.initialize()
return impl


@ -4,13 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any, cast
from urllib.parse import urlparse
import exceptiongroup
import httpx
from mcp import ClientSession
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
@ -18,14 +22,35 @@ from llama_stack.apis.tools import (
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.credentials import AuthenticationRequiredError, CredentialsStore
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import ModelContextProtocolConfig
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, None]:
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session
except BaseException as e:
# TODO: auto-discover auth metadata and cache it, add a nonce, create state
# which can be used to exchange the authorization code for an access token.
if isinstance(e, exceptiongroup.BaseExceptionGroup):
for exc in e.exceptions:
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
raise AuthenticationRequiredError(e) from e
# re-raise anything that was not translated into AuthenticationRequiredError
raise
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: ModelContextProtocolConfig, deps: dict[Api, Any]):
self.config = config
self.credentials_store = cast(CredentialsStore, deps[Api.credentials])
async def initialize(self):
pass
@ -36,31 +61,30 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
headers = await self.get_headers()
tools = []
async with sse_client_wrapper(mcp_endpoint.uri, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"endpoint": mcp_endpoint.uri,
},
)
)
return ListToolDefsResponse(data=tools)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
@ -71,12 +95,18 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
headers = await self.get_headers()
async with sse_client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool.identifier, kwargs)
return ToolInvocationResult(
content="\n".join([result.model_dump_json() for result in result.content]),
content=[result.model_dump_json() for result in result.content],
error_code=1 if result.isError else 0,
)
async def get_headers(self) -> dict[str, str]:
headers = {}
# CredentialsStore exposes read_decrypted_credential(), which returns the access token or None
token = await self.credentials_store.read_decrypted_credential(self.__provider_id__)
if token:
headers["Authorization"] = f"Bearer {token}"
return headers
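
One way a caller might react to the 401 path above (a sketch under assumptions: `list_runtime_tools`
as the entry point, a hypothetical `prompt_user_for_token` helper, and a caller-supplied token type):

async def list_tools_with_auth(runtime, credentials_api, endpoint, token_type):
    # Sketch only: on AuthenticationRequiredError (the MCP server answered 401), collect a token
    # from the user, store it through the Credentials API, and retry once.
    try:
        return await runtime.list_runtime_tools(mcp_endpoint=endpoint)
    except AuthenticationRequiredError:
        token = prompt_user_for_token()  # hypothetical helper
        await credentials_api.create_credential(
            provider_id=runtime.__provider_id__,
            token_type=token_type,
            token=token,
        )
        return await runtime.list_runtime_tools(mcp_endpoint=endpoint)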