feat: basic implementation and usage of Credentials API for MCP

This commit is contained in:
Ashwin Bharambe 2025-05-19 16:51:03 -07:00
parent b43cdaaed5
commit 6e57929ede
8 changed files with 280 additions and 45 deletions

View file

@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Protocol
from pydantic import BaseModel, Field
from llama_stack.apis.credentials import CredentialListItem, CredentialTokenType
from llama_stack.apis.credentials import Credentials as CredentialsAPI
from llama_stack.distribution.request_headers import get_logged_in_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, KVStoreConfig, kvstore_impl
from .datatypes import Api
logger = get_logger(__name__, category="core")
class AuthenticationRequiredError(Exception):
    """Exception named for the authentication-required case; adds nothing to `Exception`."""
# Record describing one credential for one provider. Instances are serialized
# to JSON when stored and rebuilt from JSON when read back (see the
# `model_dump_json` / `json.loads` calls in DistributionCredentialsImpl).
class ProviderCredential(BaseModel):
# Identifier handed back to the caller when the credential is created.
credential_id: str
# Provider this credential belongs to; also part of the storage key.
provider_id: str
token_type: CredentialTokenType
# The token value itself; stored as given (encryption is still a TODO in store_credential).
access_token: str
# Compared against the current UTC time when the credential is read.
expires_at: datetime = Field(description="The time at which the credential expires. In UTC.")
# Optional; create_credential currently always stores None here.
refresh_token: str | None = None
class CredentialsStore(Protocol):
"""This is a private protocol used by the distribution and providers to operate on credentials."""
async def read_decrypted_credential(self, provider_id: str) -> str | None: ...
# Configuration for the distribution-level credentials implementation.
class DistributionCredentialsConfig(BaseModel):
# TODO: a kvstore isn't the right primitive because we need to look up
# by both `credential_id` (for delete) and (user, provider_id) for fast look ups
# Required: config of the key-value store that holds the credentials.
kvstore: KVStoreConfig
def get_principal() -> str:
    """Return the principal that owns credentials for the current request.

    Returns:
        The logged-in user when there is one, otherwise the wildcard "*".
    """
    # Unauthenticated stack: there is no logged-in user, so the credential is
    # filed under "*" and all users have access to it.
    return get_logged_in_user() or "*"
class DistributionCredentialsImpl(CredentialsAPI, CredentialsStore):
    """Credentials API implementation that keeps credentials in a key-value store.

    Entries are keyed by (principal, provider_id) through `make_credential_key`,
    so every write for the same pair targets the same key.
    """

    def __init__(self, config: DistributionCredentialsConfig, deps: dict[Api, Any]):
        self.config = config
        self.deps = deps
        # Populated by `initialize()`; None until then.
        self.store: KVStore | None = None

    async def initialize(self) -> None:
        """Create the key-value store from `config.kvstore`."""
        self.store = await kvstore_impl(self.config.kvstore)

    async def shutdown(self) -> None:
        """Nothing to release."""

    async def get_credentials(self) -> list[CredentialListItem]:
        """List the current principal's credentials without exposing token values.

        No expiry check is applied here, so expired entries are listed too.

        Returns:
            One `CredentialListItem` per non-empty stored value found under the
            principal's key prefix.
        """
        principal = get_principal()
        assert self.store is not None

        # All keys of a principal start with "principal:<name>/"; the "\xff"
        # suffix is the upper bound of the range scan over that prefix.
        start = f"principal:{principal}/"
        end = f"principal:{principal}/\xff"

        credentials = []
        for value in await self.store.values_in_range(start, end):
            if not value:
                continue
            credential = ProviderCredential(**json.loads(value))
            credentials.append(
                CredentialListItem(
                    credential_id=credential.credential_id,
                    provider_id=credential.provider_id,
                    token_type=credential.token_type,
                    expires_at=credential.expires_at,
                )
            )
        return credentials

    @staticmethod
    def _registered_provider_ids(run_config: Any) -> set[str]:
        """Collect the provider ids declared in `run_config.providers`.

        Each value of the mapping may be a list of providers (several providers
        per API) or a single provider; both shapes are flattened.
        """
        provider_ids: set[str] = set()
        for entry in run_config.providers.values():
            providers = entry if isinstance(entry, (list, tuple)) else [entry]
            provider_ids.update(p.provider_id for p in providers)
        return provider_ids

    async def create_credential(
        self,
        provider_id: str,
        token_type: CredentialTokenType,
        token: str,
        nonce: str | None = None,
        ttl_seconds: int = 3600,
    ) -> str:
        """Store `token` as the current principal's credential for `provider_id`.

        Args:
            provider_id: Provider the credential is for; must be registered.
            token_type: Kind of token being stored.
            token: The token value, stored as the access token.
            nonce: Accepted but not used by this implementation.
            ttl_seconds: Lifetime of the credential, counted from now (UTC).

        Returns:
            The generated credential id.

        Raises:
            NotImplementedError: For OAuth2 authorization codes.
            ValueError: If `provider_id` is not a registered provider.
        """
        if token_type == CredentialTokenType.oauth2_authorization_code:
            # TODO: we need to exchange the authorization code for an access token
            # and store { access_token, refresh_token, expires_at }
            raise NotImplementedError("OAuth2 authorization code is not supported yet")

        principal = get_principal()

        # check that provider_id is registered
        # NOTE(review): assumes the inspect implementation exposes `run_config`
        # directly -- confirm.
        run_config = self.deps[Api.inspect].run_config
        # TODO: we should make provider_ids unique across all APIs which is not enforced yet
        if provider_id not in self._registered_provider_ids(run_config):
            raise ValueError(f"Provider {provider_id} is not registered")

        credential_id = str(uuid.uuid4())
        credential = ProviderCredential(
            credential_id=credential_id,
            provider_id=provider_id,
            token_type=token_type,
            access_token=token,
            expires_at=datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds),
            refresh_token=None,
        )
        await self.store_credential(principal, credential)
        return credential_id

    async def delete_credential(self, credential_id: str) -> None:
        """Delete the current principal's credential with the given id.

        Raises:
            ValueError: If the principal has no credential with that id.
        """
        principal = get_principal()
        assert self.store is not None

        # Keys are built from the provider id, not the credential id, so the
        # matching entry is found by scanning the principal's credentials.
        credentials = await self.get_credentials()
        for credential in credentials:
            if credential.credential_id == credential_id:
                await self.store.delete(make_credential_key(principal, credential.provider_id))
                return

        raise ValueError(f"Credential {credential_id} not found")

    async def store_credential(self, principal: str, credential: ProviderCredential) -> None:
        """Write `credential` as JSON under the (principal, provider_id) key."""
        # TODO: encrypt
        key = make_credential_key(principal, credential.provider_id)
        assert self.store is not None
        await self.store.set(key, credential.model_dump_json())

    async def read_decrypted_credential(self, provider_id: str) -> str | None:
        """Return the current principal's access token for `provider_id`.

        Returns:
            The stored access token, or None when nothing is stored for the
            pair or the stored credential has expired.
        """
        principal = get_principal()
        key = make_credential_key(principal, provider_id)
        assert self.store is not None

        value = await self.store.get(key)
        if not value:
            return None

        credential = ProviderCredential(**json.loads(value))
        if credential.expires_at < datetime.now(timezone.utc):
            logger.info(f"Credential {credential.credential_id} for {provider_id} has expired")
            return None

        return credential.access_token
def make_credential_key(principal: str, provider_id: str) -> str:
    """Build the storage key for the credential of `principal` at `provider_id`."""
    principal_part = f"principal:{principal}"
    provider_part = f"provider:{provider_id}"
    return "/".join((principal_part, provider_part))

View file

@ -296,6 +296,14 @@ can be instantiated multiple times (with different configs) if necessary.
Configuration for the persistence store used by the distribution registry. If not specified,
a default SQLite store will be used.""",
)
# Optional store config for per-(user, provider) credentials; None when the
# run config does not set one.
# NOTE(review): the description below reads "used for store" -- probably meant
# "used to store".
credentials_store: KVStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used for store ephemeral per-(user, provider) credentials. This
store is different since it may have different security properties (e.g. not encrypted) and may have different
reliability requirements (e.g. not durable).
""",
)
# registry of "resources" in the distribution
models: list[ModelInput] = Field(default_factory=list)

View file

@ -22,11 +22,16 @@ class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data"""
# Builds the context from optional provider data, auth attributes and the
# logged-in user.
# NOTE(review): this diff hunk shows both the previous one-line parameter list
# and its multi-line replacement; only one of them is live in the real file.
def __init__(
self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
self,
provider_data: dict[str, Any] | None = None,
auth_attributes: dict[str, list[str]] | None = None,
logged_in_user: str | None = None,
):
# Falls back to a fresh dict when no (or empty) provider data is given.
# NOTE(review): a non-empty dict passed in is used as-is, so the reserved
# keys below are written into the caller's dict rather than a copy.
self.provider_data = provider_data or {}
# Auth attributes and the logged-in user travel inside provider_data under
# reserved "__"-prefixed keys; each is only set when truthy.
if auth_attributes:
self.provider_data["__auth_attributes"] = auth_attributes
if logged_in_user:
self.provider_data["__logged_in_user"] = logged_in_user
self.token = None
@ -88,11 +93,13 @@ def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | Non
# NOTE(review): this diff hunk interleaves the removed one-line parameter list
# and return statement with their replacements; the three-parameter form is
# presumably the live one.
def request_provider_data_context(
headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None
headers: dict[str, str],
auth_attributes: dict[str, list[str]] | None = None,
logged_in_user: str | None = None,
) -> AbstractContextManager:
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
# Provider data is parsed from the request headers.
provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data, auth_attributes)
# The logged-in user is forwarded to the context alongside the auth attributes.
return RequestProviderDataContext(provider_data, auth_attributes, logged_in_user)
def get_auth_attributes() -> dict[str, list[str]] | None:
@ -101,3 +108,9 @@ def get_auth_attributes() -> dict[str, list[str]] | None:
if not provider_data:
return None
return provider_data.get("__auth_attributes")
def get_logged_in_user() -> str | None:
    """Helper to retrieve logged in user from the provider data context.

    Returns:
        The value stored under "__logged_in_user", or None when the context
        holds no provider data or no user was recorded.
    """
    provider_data = PROVIDER_DATA_VAR.get()
    # Same guard as get_auth_attributes(): without it, a missing provider data
    # value would make the `.get` call below raise AttributeError.
    if not provider_data:
        return None
    return provider_data.get("__logged_in_user")

View file

@ -209,11 +209,15 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
principal = request.scope.get("principal", None)
await log_request_pre_validation(request)
# TODO: before request execution starts, we need to check for authorization
# so we can send back 40X errors before StreamingResponse starts (and returns a 200).
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user_attributes):
with request_provider_data_context(request.headers, user_attributes, principal):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:

View file

@ -34,6 +34,7 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.credentials import DistributionCredentialsConfig, DistributionCredentialsImpl
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
@ -199,24 +200,30 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
) from e
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
"""Add internal implementations (inspect and providers) to the implementations dictionary.
Args:
impls: Dictionary of API implementations
run_config: Stack run configuration
"""
# Builds and initializes the inspect, providers and credentials
# implementations and returns them keyed by Api.
# NOTE(review): this diff hunk mixes lines from the removed synchronous
# add_internal_implementations() (the `impls[...] = ...` assignments) with the
# new async body; confirm which lines are live in the real file.
async def instantiate_internal_impls(impls: dict[Api, Any], run_config: StackRunConfig) -> dict[Api, Any]:
"""Add internal implementations (inspect, providers, credentials)."""
# Each internal implementation receives the shared `impls` dict as its deps.
inspect_impl = DistributionInspectImpl(
DistributionInspectConfig(run_config=run_config),
deps=impls,
)
impls[Api.inspect] = inspect_impl
await inspect_impl.initialize()
providers_impl = ProviderImpl(
ProviderImplConfig(run_config=run_config),
deps=impls,
)
impls[Api.providers] = providers_impl
await providers_impl.initialize()
# NOTE(review): `run_config.credentials_store` appears to be optional
# (default None) while DistributionCredentialsConfig.kvstore is declared as a
# required KVStoreConfig -- confirm a store is always configured, or supply a
# default here.
credentials_impl = DistributionCredentialsImpl(
DistributionCredentialsConfig(kvstore=run_config.credentials_store),
deps=impls,
)
await credentials_impl.initialize()
return {
Api.inspect: inspect_impl,
Api.providers: providers_impl,
Api.credentials: credentials_impl,
}
# Produces a stack of providers for the given run config. Not all APIs may be
@ -228,7 +235,14 @@ async def construct_stack(
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, run_config)
internal_impls = await instantiate_internal_impls(impls, run_config)
impls.update(internal_impls)
# credentials_store = internal_impls[Api.credentials]
# for impl in impls.values():
# # in an ideal world, we would pass the credentials store as a dependency
# if hasattr(impl, "credentials_store"):
# impl.credentials_store = credentials_store
await register_resources(run_config, impls)
return impls