mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
feat: basic implementation and usage of Credentials API for MCP
This commit is contained in:
parent
b43cdaaed5
commit
6e57929ede
8 changed files with 280 additions and 45 deletions
161
llama_stack/distribution/credentials.py
Normal file
161
llama_stack/distribution/credentials.py
Normal file
|
@ -0,0 +1,161 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.apis.credentials import CredentialListItem, CredentialTokenType
|
||||||
|
from llama_stack.apis.credentials import Credentials as CredentialsAPI
|
||||||
|
from llama_stack.distribution.request_headers import get_logged_in_user
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.utils.kvstore import KVStore, KVStoreConfig, kvstore_impl
|
||||||
|
|
||||||
|
from .datatypes import Api
|
||||||
|
|
||||||
|
logger = get_logger(__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationRequiredError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderCredential(BaseModel):
|
||||||
|
credential_id: str
|
||||||
|
provider_id: str
|
||||||
|
token_type: CredentialTokenType
|
||||||
|
access_token: str
|
||||||
|
expires_at: datetime = Field(description="The time at which the credential expires. In UTC.")
|
||||||
|
refresh_token: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialsStore(Protocol):
|
||||||
|
"""This is a private protocol used by the distribution and providers to operate on credentials."""
|
||||||
|
|
||||||
|
async def read_decrypted_credential(self, provider_id: str) -> str | None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class DistributionCredentialsConfig(BaseModel):
|
||||||
|
# TODO: a kvstore isn't the right primitive because we need to look up
|
||||||
|
# by both `credential_id` (for delete) and (user, provider_id) for fast look ups
|
||||||
|
kvstore: KVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
|
def get_principal() -> str:
|
||||||
|
logged_in_user = get_logged_in_user()
|
||||||
|
if not logged_in_user:
|
||||||
|
# unauth stack, all users have access to this credential
|
||||||
|
principal = "*"
|
||||||
|
else:
|
||||||
|
principal = logged_in_user
|
||||||
|
return principal
|
||||||
|
|
||||||
|
|
||||||
|
class DistributionCredentialsImpl(CredentialsAPI, CredentialsStore):
|
||||||
|
def __init__(self, config: DistributionCredentialsConfig, deps: dict[Api, Any]):
|
||||||
|
self.config = config
|
||||||
|
self.deps = deps
|
||||||
|
self.store: KVStore | None = None
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
self.store = await kvstore_impl(self.config.kvstore)
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def get_credentials(self) -> list[CredentialListItem]:
|
||||||
|
principal = get_principal()
|
||||||
|
assert self.store is not None
|
||||||
|
|
||||||
|
credentials = []
|
||||||
|
start = f"principal:{principal}/"
|
||||||
|
end = f"principal:{principal}/\xff"
|
||||||
|
for value in await self.store.values_in_range(start, end):
|
||||||
|
if not value:
|
||||||
|
continue
|
||||||
|
credential = ProviderCredential(**json.loads(value))
|
||||||
|
credentials.append(
|
||||||
|
CredentialListItem(
|
||||||
|
credential_id=credential.credential_id,
|
||||||
|
provider_id=credential.provider_id,
|
||||||
|
token_type=credential.token_type,
|
||||||
|
expires_at=credential.expires_at,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return credentials
|
||||||
|
|
||||||
|
async def create_credential(
|
||||||
|
self,
|
||||||
|
provider_id: str,
|
||||||
|
token_type: CredentialTokenType,
|
||||||
|
token: str,
|
||||||
|
nonce: str | None = None,
|
||||||
|
ttl_seconds: int = 3600,
|
||||||
|
) -> str:
|
||||||
|
if token_type == CredentialTokenType.oauth2_authorization_code:
|
||||||
|
# TODO: we need to exchange the authorization code for an access token
|
||||||
|
# and store { access_token, refresh_token, expires_at }
|
||||||
|
raise NotImplementedError("OAuth2 authorization code is not supported yet")
|
||||||
|
|
||||||
|
principal = get_principal()
|
||||||
|
|
||||||
|
# check that provider_id is registered
|
||||||
|
run_config = self.deps[Api.inspect].run_config
|
||||||
|
|
||||||
|
# TODO: we should make provider_ids unique across all APIs which is not enforced yet
|
||||||
|
provider_ids = [p.provider_id for p in run_config.providers.values()]
|
||||||
|
if provider_id not in provider_ids:
|
||||||
|
raise ValueError(f"Provider {provider_id} is not registered")
|
||||||
|
|
||||||
|
credential_id = str(uuid.uuid4())
|
||||||
|
credential = ProviderCredential(
|
||||||
|
credential_id=credential_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
token_type=token_type,
|
||||||
|
access_token=token,
|
||||||
|
expires_at=datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds),
|
||||||
|
refresh_token=None,
|
||||||
|
)
|
||||||
|
await self.store_credential(principal, credential)
|
||||||
|
return credential_id
|
||||||
|
|
||||||
|
async def delete_credential(self, credential_id: str) -> None:
|
||||||
|
principal = get_principal()
|
||||||
|
assert self.store is not None
|
||||||
|
|
||||||
|
credentials = await self.get_credentials()
|
||||||
|
for credential in credentials:
|
||||||
|
if credential.credential_id == credential_id:
|
||||||
|
await self.store.delete(make_credential_key(principal, credential.provider_id))
|
||||||
|
return
|
||||||
|
raise ValueError(f"Credential {credential_id} not found")
|
||||||
|
|
||||||
|
async def store_credential(self, principal: str, credential: ProviderCredential) -> None:
|
||||||
|
# TODO: encrypt
|
||||||
|
key = make_credential_key(principal, credential.provider_id)
|
||||||
|
assert self.store is not None
|
||||||
|
await self.store.set(key, credential.model_dump_json())
|
||||||
|
|
||||||
|
async def read_decrypted_credential(self, provider_id: str) -> str | None:
|
||||||
|
principal = get_principal()
|
||||||
|
|
||||||
|
key = make_credential_key(principal, provider_id)
|
||||||
|
assert self.store is not None
|
||||||
|
value = await self.store.get(key)
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
credential = ProviderCredential(**json.loads(value))
|
||||||
|
if credential.expires_at < datetime.now(timezone.utc):
|
||||||
|
logger.info(f"Credential {credential.credential_id} for {provider_id} has expired")
|
||||||
|
return None
|
||||||
|
return credential.access_token
|
||||||
|
|
||||||
|
|
||||||
|
def make_credential_key(principal: str, provider_id: str) -> str:
|
||||||
|
return f"principal:{principal}/provider:{provider_id}"
|
|
@ -296,6 +296,14 @@ can be instantiated multiple times (with different configs) if necessary.
|
||||||
Configuration for the persistence store used by the distribution registry. If not specified,
|
Configuration for the persistence store used by the distribution registry. If not specified,
|
||||||
a default SQLite store will be used.""",
|
a default SQLite store will be used.""",
|
||||||
)
|
)
|
||||||
|
credentials_store: KVStoreConfig | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="""
|
||||||
|
Configuration for the persistence store used for store ephemeral per-(user, provider) credentials. This
|
||||||
|
store is different since it may have different security properties (e.g. not encrypted) and may have different
|
||||||
|
reliability requirements (e.g. not durable).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
# registry of "resources" in the distribution
|
# registry of "resources" in the distribution
|
||||||
models: list[ModelInput] = Field(default_factory=list)
|
models: list[ModelInput] = Field(default_factory=list)
|
||||||
|
|
|
@ -22,11 +22,16 @@ class RequestProviderDataContext(AbstractContextManager):
|
||||||
"""Context manager for request provider data"""
|
"""Context manager for request provider data"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
|
self,
|
||||||
|
provider_data: dict[str, Any] | None = None,
|
||||||
|
auth_attributes: dict[str, list[str]] | None = None,
|
||||||
|
logged_in_user: str | None = None,
|
||||||
):
|
):
|
||||||
self.provider_data = provider_data or {}
|
self.provider_data = provider_data or {}
|
||||||
if auth_attributes:
|
if auth_attributes:
|
||||||
self.provider_data["__auth_attributes"] = auth_attributes
|
self.provider_data["__auth_attributes"] = auth_attributes
|
||||||
|
if logged_in_user:
|
||||||
|
self.provider_data["__logged_in_user"] = logged_in_user
|
||||||
|
|
||||||
self.token = None
|
self.token = None
|
||||||
|
|
||||||
|
@ -88,11 +93,13 @@ def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | Non
|
||||||
|
|
||||||
|
|
||||||
def request_provider_data_context(
|
def request_provider_data_context(
|
||||||
headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None
|
headers: dict[str, str],
|
||||||
|
auth_attributes: dict[str, list[str]] | None = None,
|
||||||
|
logged_in_user: str | None = None,
|
||||||
) -> AbstractContextManager:
|
) -> AbstractContextManager:
|
||||||
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
|
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
|
||||||
provider_data = parse_request_provider_data(headers)
|
provider_data = parse_request_provider_data(headers)
|
||||||
return RequestProviderDataContext(provider_data, auth_attributes)
|
return RequestProviderDataContext(provider_data, auth_attributes, logged_in_user)
|
||||||
|
|
||||||
|
|
||||||
def get_auth_attributes() -> dict[str, list[str]] | None:
|
def get_auth_attributes() -> dict[str, list[str]] | None:
|
||||||
|
@ -101,3 +108,9 @@ def get_auth_attributes() -> dict[str, list[str]] | None:
|
||||||
if not provider_data:
|
if not provider_data:
|
||||||
return None
|
return None
|
||||||
return provider_data.get("__auth_attributes")
|
return provider_data.get("__auth_attributes")
|
||||||
|
|
||||||
|
|
||||||
|
def get_logged_in_user() -> str | None:
|
||||||
|
"""Helper to retrieve logged in user from the provider data context"""
|
||||||
|
provider_data = PROVIDER_DATA_VAR.get()
|
||||||
|
return provider_data.get("__logged_in_user")
|
||||||
|
|
|
@ -209,11 +209,15 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||||
async def endpoint(request: Request, **kwargs):
|
async def endpoint(request: Request, **kwargs):
|
||||||
# Get auth attributes from the request scope
|
# Get auth attributes from the request scope
|
||||||
user_attributes = request.scope.get("user_attributes", {})
|
user_attributes = request.scope.get("user_attributes", {})
|
||||||
|
principal = request.scope.get("principal", None)
|
||||||
|
|
||||||
await log_request_pre_validation(request)
|
await log_request_pre_validation(request)
|
||||||
|
|
||||||
|
# TODO: before request execution starts, we need to check for authorization
|
||||||
|
# so we can send back 40X errors before StreamingResponse starts (and returns a 200).
|
||||||
|
|
||||||
# Use context manager with both provider data and auth attributes
|
# Use context manager with both provider data and auth attributes
|
||||||
with request_provider_data_context(request.headers, user_attributes):
|
with request_provider_data_context(request.headers, user_attributes, principal):
|
||||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -34,6 +34,7 @@ from llama_stack.apis.telemetry import Telemetry
|
||||||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||||
from llama_stack.apis.vector_dbs import VectorDBs
|
from llama_stack.apis.vector_dbs import VectorDBs
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
|
from llama_stack.distribution.credentials import DistributionCredentialsConfig, DistributionCredentialsImpl
|
||||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
|
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||||
|
@ -199,24 +200,30 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
|
async def instantiate_internal_impls(impls: dict[Api, Any], run_config: StackRunConfig) -> dict[Api, Any]:
|
||||||
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
"""Add internal implementations (inspect, providers, credentials)."""
|
||||||
|
|
||||||
Args:
|
|
||||||
impls: Dictionary of API implementations
|
|
||||||
run_config: Stack run configuration
|
|
||||||
"""
|
|
||||||
inspect_impl = DistributionInspectImpl(
|
inspect_impl = DistributionInspectImpl(
|
||||||
DistributionInspectConfig(run_config=run_config),
|
DistributionInspectConfig(run_config=run_config),
|
||||||
deps=impls,
|
deps=impls,
|
||||||
)
|
)
|
||||||
impls[Api.inspect] = inspect_impl
|
await inspect_impl.initialize()
|
||||||
|
|
||||||
providers_impl = ProviderImpl(
|
providers_impl = ProviderImpl(
|
||||||
ProviderImplConfig(run_config=run_config),
|
ProviderImplConfig(run_config=run_config),
|
||||||
deps=impls,
|
deps=impls,
|
||||||
)
|
)
|
||||||
impls[Api.providers] = providers_impl
|
await providers_impl.initialize()
|
||||||
|
|
||||||
|
credentials_impl = DistributionCredentialsImpl(
|
||||||
|
DistributionCredentialsConfig(kvstore=run_config.credentials_store),
|
||||||
|
deps=impls,
|
||||||
|
)
|
||||||
|
await credentials_impl.initialize()
|
||||||
|
return {
|
||||||
|
Api.inspect: inspect_impl,
|
||||||
|
Api.providers: providers_impl,
|
||||||
|
Api.credentials: credentials_impl,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||||
|
@ -228,7 +235,14 @@ async def construct_stack(
|
||||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
||||||
|
|
||||||
# Add internal implementations after all other providers are resolved
|
# Add internal implementations after all other providers are resolved
|
||||||
add_internal_implementations(impls, run_config)
|
internal_impls = await instantiate_internal_impls(impls, run_config)
|
||||||
|
impls.update(internal_impls)
|
||||||
|
|
||||||
|
# credentials_store = internal_impls[Api.credentials]
|
||||||
|
# for impl in impls.values():
|
||||||
|
# # in an ideal world, we would pass the credentials store as a dependency
|
||||||
|
# if hasattr(impl, "credentials_store"):
|
||||||
|
# impl.credentials_store = credentials_store
|
||||||
|
|
||||||
await register_resources(run_config, impls)
|
await register_resources(run_config, impls)
|
||||||
return impls
|
return impls
|
||||||
|
|
|
@ -83,5 +83,6 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
|
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
|
||||||
pip_packages=["mcp"],
|
pip_packages=["mcp"],
|
||||||
),
|
),
|
||||||
|
api_dependencies=[Api.credentials],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -4,8 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.apis.datatypes import Api
|
||||||
|
|
||||||
from .config import ModelContextProtocolConfig
|
from .config import ModelContextProtocolConfig
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,9 +17,9 @@ class ModelContextProtocolToolProviderDataValidator(BaseModel):
|
||||||
api_key: str
|
api_key: str
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
|
async def get_adapter_impl(config: ModelContextProtocolConfig, deps: dict[Api, Any]):
|
||||||
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
|
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
|
||||||
|
|
||||||
impl = ModelContextProtocolToolRuntimeImpl(config)
|
impl = ModelContextProtocolToolRuntimeImpl(config, deps)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -4,13 +4,17 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any, cast
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import exceptiongroup
|
||||||
|
import httpx
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
|
from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
@ -18,14 +22,35 @@ from llama_stack.apis.tools import (
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
|
from llama_stack.distribution.credentials import AuthenticationRequiredError, CredentialsStore
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
|
||||||
from .config import ModelContextProtocolConfig
|
from .config import ModelContextProtocolConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, None]:
|
||||||
|
try:
|
||||||
|
async with sse_client(endpoint, headers=headers) as streams:
|
||||||
|
async with ClientSession(*streams) as session:
|
||||||
|
await session.initialize()
|
||||||
|
yield session
|
||||||
|
except BaseException as e:
|
||||||
|
# TODO: auto-discover auth metadata and cache it, add a nonce, create state
|
||||||
|
# which can be used to exchange the authorization code for an access token.
|
||||||
|
if isinstance(e, exceptiongroup.BaseExceptionGroup):
|
||||||
|
for exc in e.exceptions:
|
||||||
|
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
|
||||||
|
raise AuthenticationRequiredError(exc) from exc
|
||||||
|
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
|
||||||
|
raise AuthenticationRequiredError(e) from e
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
def __init__(self, config: ModelContextProtocolConfig):
|
def __init__(self, config: ModelContextProtocolConfig, deps: dict[Api, Any]):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.credentials_store = cast(CredentialsStore, deps[Api.credentials])
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
@ -36,31 +61,30 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
if mcp_endpoint is None:
|
if mcp_endpoint is None:
|
||||||
raise ValueError("mcp_endpoint is required")
|
raise ValueError("mcp_endpoint is required")
|
||||||
|
|
||||||
|
headers = await self.get_headers()
|
||||||
tools = []
|
tools = []
|
||||||
async with sse_client(mcp_endpoint.uri) as streams:
|
async with sse_client_wrapper(mcp_endpoint.uri, headers) as session:
|
||||||
async with ClientSession(*streams) as session:
|
tools_result = await session.list_tools()
|
||||||
await session.initialize()
|
for tool in tools_result.tools:
|
||||||
tools_result = await session.list_tools()
|
parameters = []
|
||||||
for tool in tools_result.tools:
|
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
|
||||||
parameters = []
|
parameters.append(
|
||||||
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
|
ToolParameter(
|
||||||
parameters.append(
|
name=param_name,
|
||||||
ToolParameter(
|
parameter_type=param_schema.get("type", "string"),
|
||||||
name=param_name,
|
description=param_schema.get("description", ""),
|
||||||
parameter_type=param_schema.get("type", "string"),
|
|
||||||
description=param_schema.get("description", ""),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
tools.append(
|
|
||||||
ToolDef(
|
|
||||||
name=tool.name,
|
|
||||||
description=tool.description,
|
|
||||||
parameters=parameters,
|
|
||||||
metadata={
|
|
||||||
"endpoint": mcp_endpoint.uri,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
tools.append(
|
||||||
|
ToolDef(
|
||||||
|
name=tool.name,
|
||||||
|
description=tool.description,
|
||||||
|
parameters=parameters,
|
||||||
|
metadata={
|
||||||
|
"endpoint": mcp_endpoint.uri,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
return ListToolDefsResponse(data=tools)
|
return ListToolDefsResponse(data=tools)
|
||||||
|
|
||||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||||
|
@ -71,12 +95,18 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
if urlparse(endpoint).scheme not in ("http", "https"):
|
if urlparse(endpoint).scheme not in ("http", "https"):
|
||||||
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
|
||||||
|
|
||||||
async with sse_client(endpoint) as streams:
|
headers = await self.get_headers()
|
||||||
async with ClientSession(*streams) as session:
|
async with sse_client_wrapper(endpoint, headers) as session:
|
||||||
await session.initialize()
|
result = await session.call_tool(tool.identifier, kwargs)
|
||||||
result = await session.call_tool(tool.identifier, kwargs)
|
|
||||||
|
|
||||||
return ToolInvocationResult(
|
return ToolInvocationResult(
|
||||||
content="\n".join([result.model_dump_json() for result in result.content]),
|
content=[result.model_dump_json() for result in result.content],
|
||||||
error_code=1 if result.isError else 0,
|
error_code=1 if result.isError else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_headers(self) -> dict[str, str]:
|
||||||
|
headers = {}
|
||||||
|
credentials = await self.credentials_store.get_credential(self.__provider_id__)
|
||||||
|
if credentials:
|
||||||
|
headers["Authorization"] = f"Bearer {credentials.token}"
|
||||||
|
return headers
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue