mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
feat: basic implementation and usage of Credentials API for MCP
This commit is contained in:
parent
b43cdaaed5
commit
6e57929ede
8 changed files with 280 additions and 45 deletions
161
llama_stack/distribution/credentials.py
Normal file
161
llama_stack/distribution/credentials.py
Normal file
|
@ -0,0 +1,161 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.credentials import CredentialListItem, CredentialTokenType
|
||||
from llama_stack.apis.credentials import Credentials as CredentialsAPI
|
||||
from llama_stack.distribution.request_headers import get_logged_in_user
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore, KVStoreConfig, kvstore_impl
|
||||
|
||||
from .datatypes import Api
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
class AuthenticationRequiredError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ProviderCredential(BaseModel):
|
||||
credential_id: str
|
||||
provider_id: str
|
||||
token_type: CredentialTokenType
|
||||
access_token: str
|
||||
expires_at: datetime = Field(description="The time at which the credential expires. In UTC.")
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
class CredentialsStore(Protocol):
|
||||
"""This is a private protocol used by the distribution and providers to operate on credentials."""
|
||||
|
||||
async def read_decrypted_credential(self, provider_id: str) -> str | None: ...
|
||||
|
||||
|
||||
class DistributionCredentialsConfig(BaseModel):
|
||||
# TODO: a kvstore isn't the right primitive because we need to look up
|
||||
# by both `credential_id` (for delete) and (user, provider_id) for fast look ups
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
|
||||
def get_principal() -> str:
|
||||
logged_in_user = get_logged_in_user()
|
||||
if not logged_in_user:
|
||||
# unauth stack, all users have access to this credential
|
||||
principal = "*"
|
||||
else:
|
||||
principal = logged_in_user
|
||||
return principal
|
||||
|
||||
|
||||
class DistributionCredentialsImpl(CredentialsAPI, CredentialsStore):
|
||||
def __init__(self, config: DistributionCredentialsConfig, deps: dict[Api, Any]):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
self.store: KVStore | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.store = await kvstore_impl(self.config.kvstore)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def get_credentials(self) -> list[CredentialListItem]:
|
||||
principal = get_principal()
|
||||
assert self.store is not None
|
||||
|
||||
credentials = []
|
||||
start = f"principal:{principal}/"
|
||||
end = f"principal:{principal}/\xff"
|
||||
for value in await self.store.values_in_range(start, end):
|
||||
if not value:
|
||||
continue
|
||||
credential = ProviderCredential(**json.loads(value))
|
||||
credentials.append(
|
||||
CredentialListItem(
|
||||
credential_id=credential.credential_id,
|
||||
provider_id=credential.provider_id,
|
||||
token_type=credential.token_type,
|
||||
expires_at=credential.expires_at,
|
||||
)
|
||||
)
|
||||
return credentials
|
||||
|
||||
async def create_credential(
|
||||
self,
|
||||
provider_id: str,
|
||||
token_type: CredentialTokenType,
|
||||
token: str,
|
||||
nonce: str | None = None,
|
||||
ttl_seconds: int = 3600,
|
||||
) -> str:
|
||||
if token_type == CredentialTokenType.oauth2_authorization_code:
|
||||
# TODO: we need to exchange the authorization code for an access token
|
||||
# and store { access_token, refresh_token, expires_at }
|
||||
raise NotImplementedError("OAuth2 authorization code is not supported yet")
|
||||
|
||||
principal = get_principal()
|
||||
|
||||
# check that provider_id is registered
|
||||
run_config = self.deps[Api.inspect].run_config
|
||||
|
||||
# TODO: we should make provider_ids unique across all APIs which is not enforced yet
|
||||
provider_ids = [p.provider_id for p in run_config.providers.values()]
|
||||
if provider_id not in provider_ids:
|
||||
raise ValueError(f"Provider {provider_id} is not registered")
|
||||
|
||||
credential_id = str(uuid.uuid4())
|
||||
credential = ProviderCredential(
|
||||
credential_id=credential_id,
|
||||
provider_id=provider_id,
|
||||
token_type=token_type,
|
||||
access_token=token,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds),
|
||||
refresh_token=None,
|
||||
)
|
||||
await self.store_credential(principal, credential)
|
||||
return credential_id
|
||||
|
||||
async def delete_credential(self, credential_id: str) -> None:
|
||||
principal = get_principal()
|
||||
assert self.store is not None
|
||||
|
||||
credentials = await self.get_credentials()
|
||||
for credential in credentials:
|
||||
if credential.credential_id == credential_id:
|
||||
await self.store.delete(make_credential_key(principal, credential.provider_id))
|
||||
return
|
||||
raise ValueError(f"Credential {credential_id} not found")
|
||||
|
||||
async def store_credential(self, principal: str, credential: ProviderCredential) -> None:
|
||||
# TODO: encrypt
|
||||
key = make_credential_key(principal, credential.provider_id)
|
||||
assert self.store is not None
|
||||
await self.store.set(key, credential.model_dump_json())
|
||||
|
||||
async def read_decrypted_credential(self, provider_id: str) -> str | None:
|
||||
principal = get_principal()
|
||||
|
||||
key = make_credential_key(principal, provider_id)
|
||||
assert self.store is not None
|
||||
value = await self.store.get(key)
|
||||
if not value:
|
||||
return None
|
||||
credential = ProviderCredential(**json.loads(value))
|
||||
if credential.expires_at < datetime.now(timezone.utc):
|
||||
logger.info(f"Credential {credential.credential_id} for {provider_id} has expired")
|
||||
return None
|
||||
return credential.access_token
|
||||
|
||||
|
||||
def make_credential_key(principal: str, provider_id: str) -> str:
|
||||
return f"principal:{principal}/provider:{provider_id}"
|
|
@ -296,6 +296,14 @@ can be instantiated multiple times (with different configs) if necessary.
|
|||
Configuration for the persistence store used by the distribution registry. If not specified,
|
||||
a default SQLite store will be used.""",
|
||||
)
|
||||
credentials_store: KVStoreConfig | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Configuration for the persistence store used for store ephemeral per-(user, provider) credentials. This
|
||||
store is different since it may have different security properties (e.g. not encrypted) and may have different
|
||||
reliability requirements (e.g. not durable).
|
||||
""",
|
||||
)
|
||||
|
||||
# registry of "resources" in the distribution
|
||||
models: list[ModelInput] = Field(default_factory=list)
|
||||
|
|
|
@ -22,11 +22,16 @@ class RequestProviderDataContext(AbstractContextManager):
|
|||
"""Context manager for request provider data"""
|
||||
|
||||
def __init__(
|
||||
self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
|
||||
self,
|
||||
provider_data: dict[str, Any] | None = None,
|
||||
auth_attributes: dict[str, list[str]] | None = None,
|
||||
logged_in_user: str | None = None,
|
||||
):
|
||||
self.provider_data = provider_data or {}
|
||||
if auth_attributes:
|
||||
self.provider_data["__auth_attributes"] = auth_attributes
|
||||
if logged_in_user:
|
||||
self.provider_data["__logged_in_user"] = logged_in_user
|
||||
|
||||
self.token = None
|
||||
|
||||
|
@ -88,11 +93,13 @@ def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | Non
|
|||
|
||||
|
||||
def request_provider_data_context(
|
||||
headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None
|
||||
headers: dict[str, str],
|
||||
auth_attributes: dict[str, list[str]] | None = None,
|
||||
logged_in_user: str | None = None,
|
||||
) -> AbstractContextManager:
|
||||
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
|
||||
provider_data = parse_request_provider_data(headers)
|
||||
return RequestProviderDataContext(provider_data, auth_attributes)
|
||||
return RequestProviderDataContext(provider_data, auth_attributes, logged_in_user)
|
||||
|
||||
|
||||
def get_auth_attributes() -> dict[str, list[str]] | None:
|
||||
|
@ -101,3 +108,9 @@ def get_auth_attributes() -> dict[str, list[str]] | None:
|
|||
if not provider_data:
|
||||
return None
|
||||
return provider_data.get("__auth_attributes")
|
||||
|
||||
|
||||
def get_logged_in_user() -> str | None:
|
||||
"""Helper to retrieve logged in user from the provider data context"""
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
return provider_data.get("__logged_in_user")
|
||||
|
|
|
@ -209,11 +209,15 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
async def endpoint(request: Request, **kwargs):
|
||||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
principal = request.scope.get("principal", None)
|
||||
|
||||
await log_request_pre_validation(request)
|
||||
|
||||
# TODO: before request execution starts, we need to check for authorization
|
||||
# so we can send back 40X errors before StreamingResponse starts (and returns a 200).
|
||||
|
||||
# Use context manager with both provider data and auth attributes
|
||||
with request_provider_data_context(request.headers, user_attributes):
|
||||
with request_provider_data_context(request.headers, user_attributes, principal):
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
try:
|
||||
|
|
|
@ -34,6 +34,7 @@ from llama_stack.apis.telemetry import Telemetry
|
|||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.credentials import DistributionCredentialsConfig, DistributionCredentialsImpl
|
||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
|
@ -199,24 +200,30 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
|||
) from e
|
||||
|
||||
|
||||
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
|
||||
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
||||
|
||||
Args:
|
||||
impls: Dictionary of API implementations
|
||||
run_config: Stack run configuration
|
||||
"""
|
||||
async def instantiate_internal_impls(impls: dict[Api, Any], run_config: StackRunConfig) -> dict[Api, Any]:
|
||||
"""Add internal implementations (inspect, providers, credentials)."""
|
||||
inspect_impl = DistributionInspectImpl(
|
||||
DistributionInspectConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.inspect] = inspect_impl
|
||||
await inspect_impl.initialize()
|
||||
|
||||
providers_impl = ProviderImpl(
|
||||
ProviderImplConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.providers] = providers_impl
|
||||
await providers_impl.initialize()
|
||||
|
||||
credentials_impl = DistributionCredentialsImpl(
|
||||
DistributionCredentialsConfig(kvstore=run_config.credentials_store),
|
||||
deps=impls,
|
||||
)
|
||||
await credentials_impl.initialize()
|
||||
return {
|
||||
Api.inspect: inspect_impl,
|
||||
Api.providers: providers_impl,
|
||||
Api.credentials: credentials_impl,
|
||||
}
|
||||
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
|
@ -228,7 +235,14 @@ async def construct_stack(
|
|||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, run_config)
|
||||
internal_impls = await instantiate_internal_impls(impls, run_config)
|
||||
impls.update(internal_impls)
|
||||
|
||||
# credentials_store = internal_impls[Api.credentials]
|
||||
# for impl in impls.values():
|
||||
# # in an ideal world, we would pass the credentials store as a dependency
|
||||
# if hasattr(impl, "credentials_store"):
|
||||
# impl.credentials_store = credentials_store
|
||||
|
||||
await register_resources(run_config, impls)
|
||||
return impls
|
||||
|
|
|
@ -83,5 +83,6 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
|
||||
pip_packages=["mcp"],
|
||||
),
|
||||
api_dependencies=[Api.credentials],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.datatypes import Api
|
||||
|
||||
from .config import ModelContextProtocolConfig
|
||||
|
||||
|
||||
|
@ -13,9 +17,9 @@ class ModelContextProtocolToolProviderDataValidator(BaseModel):
|
|||
api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
|
||||
async def get_adapter_impl(config: ModelContextProtocolConfig, deps: dict[Api, Any]):
|
||||
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
|
||||
|
||||
impl = ModelContextProtocolToolRuntimeImpl(config)
|
||||
impl = ModelContextProtocolToolRuntimeImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -4,13 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import exceptiongroup
|
||||
import httpx
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
ToolDef,
|
||||
|
@ -18,14 +22,35 @@ from llama_stack.apis.tools import (
|
|||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.distribution.credentials import AuthenticationRequiredError, CredentialsStore
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
from .config import ModelContextProtocolConfig
|
||||
|
||||
|
||||
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, None]:
|
||||
try:
|
||||
async with sse_client(endpoint, headers=headers) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
except BaseException as e:
|
||||
# TODO: auto-discover auth metadata and cache it, add a nonce, create state
|
||||
# which can be used to exchange the authorization code for an access token.
|
||||
if isinstance(e, exceptiongroup.BaseExceptionGroup):
|
||||
for exc in e.exceptions:
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
|
||||
raise AuthenticationRequiredError(exc) from exc
|
||||
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
|
||||
raise AuthenticationRequiredError(e) from e
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||
def __init__(self, config: ModelContextProtocolConfig):
|
||||
def __init__(self, config: ModelContextProtocolConfig, deps: dict[Api, Any]):
|
||||
self.config = config
|
||||
self.credentials_store = cast(CredentialsStore, deps[Api.credentials])
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
@ -36,31 +61,30 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
if mcp_endpoint is None:
|
||||
raise ValueError("mcp_endpoint is required")
|
||||
|
||||
headers = await self.get_headers()
|
||||
tools = []
|
||||
async with sse_client(mcp_endpoint.uri) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
parameters = []
|
||||
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=param_name,
|
||||
parameter_type=param_schema.get("type", "string"),
|
||||
description=param_schema.get("description", ""),
|
||||
)
|
||||
)
|
||||
tools.append(
|
||||
ToolDef(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=parameters,
|
||||
metadata={
|
||||
"endpoint": mcp_endpoint.uri,
|
||||
},
|
||||
async with sse_client_wrapper(mcp_endpoint.uri, headers) as session:
|
||||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
parameters = []
|
||||
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=param_name,
|
||||
parameter_type=param_schema.get("type", "string"),
|
||||
description=param_schema.get("description", ""),
|
||||
)
|
||||
)
|
||||
tools.append(
|
||||
ToolDef(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=parameters,
|
||||
metadata={
|
||||
"endpoint": mcp_endpoint.uri,
|
||||
},
|
||||
)
|
||||
)
|
||||
return ListToolDefsResponse(data=tools)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
|
@ -71,12 +95,18 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
if urlparse(endpoint).scheme not in ("http", "https"):
|
||||
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
||||
|
||||
async with sse_client(endpoint) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
result = await session.call_tool(tool.identifier, kwargs)
|
||||
headers = await self.get_headers()
|
||||
async with sse_client_wrapper(endpoint, headers) as session:
|
||||
result = await session.call_tool(tool.identifier, kwargs)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content="\n".join([result.model_dump_json() for result in result.content]),
|
||||
content=[result.model_dump_json() for result in result.content],
|
||||
error_code=1 if result.isError else 0,
|
||||
)
|
||||
|
||||
async def get_headers(self) -> dict[str, str]:
|
||||
headers = {}
|
||||
credentials = await self.credentials_store.get_credential(self.__provider_id__)
|
||||
if credentials:
|
||||
headers["Authorization"] = f"Bearer {credentials.token}"
|
||||
return headers
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue