This commit is contained in:
Kai Wu 2025-08-03 14:01:27 -07:00
parent 4e19f15bca
commit dcc47c2008
6 changed files with 351 additions and 8927 deletions

View file

@ -7,202 +7,11 @@
import json
import os
import streamlit as st
from typing import Dict, List, Optional, Any
from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack_client import LlamaStackClient
from llama_stack_client.types.toolgroup_register_params import McpEndpoint
# Constants
CONFIG_DIR = os.path.expanduser("~/.llama_stack/mcp_servers")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
# Ensure config directory exists
os.makedirs(CONFIG_DIR, exist_ok=True)
def load_config() -> Dict[str, Any]:
"""Load MCP server configurations from file."""
if os.path.exists(CONFIG_FILE):
try:
with open(CONFIG_FILE, "r") as f:
return json.load(f)
except json.JSONDecodeError:
st.error("Error loading MCP server configuration file. Using default configuration.")
# Default empty configuration
return {
"servers": {},
"inputs": []
}
def register_mcp_servers(config: Dict[str, Any]) -> None:
"""Register MCP servers as toolgroups with the LlamaStackClient."""
servers = config.get("servers", {})
inputs = config.get("inputs", [])
# Process inputs to resolve values
input_values = {}
for input_config in inputs:
input_id = input_config.get("id")
if input_id:
# Get value from session state if available
input_values[input_id] = st.session_state.get(f"input_value_{input_id}", "")
# Update provider data with MCP headers
mcp_headers = {}
# Register each server as a toolgroup
for server_name, server_config in servers.items():
url = server_config.get("url", "")
if not url:
continue
try:
# Register the MCP server as a toolgroup
toolgroup_id = f"mcp::{server_name}"
# Register the toolgroup
llama_stack_api.client.toolgroups.register(
toolgroup_id=toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=url,
)
# Process headers for this server
headers = server_config.get("headers", {})
processed_headers = {}
for header_key, header_value in headers.items():
# Process input references in header values
if isinstance(header_value, str) and "${input:" in header_value:
# Extract input ID from ${input:input_id} format
input_id = header_value.split("${input:")[1].split("}")[0]
if input_id in input_values:
processed_headers[header_key] = input_values[input_id]
else:
processed_headers[header_key] = header_value
# Add headers to mcp_headers if there are any
if processed_headers:
mcp_headers[url] = processed_headers
except Exception as e:
st.error(f"Failed to register MCP server '{server_name}': {str(e)}")
# Update provider data with MCP headers if there are any
if mcp_headers:
provider_data = llama_stack_api.provider_data.copy()
provider_data["mcp_headers"] = mcp_headers
llama_stack_api.update_provider_data_dict(provider_data)
def save_config(config: Dict[str, Any]) -> None:
"""Save MCP server configurations to file."""
with open(CONFIG_FILE, "w") as f:
json.dump(config, f, indent=2)
# Register MCP servers as toolgroups
register_mcp_servers(config)
st.success("MCP server configuration saved successfully!")
def render_server_config(server_name: str, server_config: Dict[str, Any], config: Dict[str, Any]) -> Dict[str, Any]:
"""Render and edit configuration for a specific MCP server."""
st.subheader(f"Server: {server_name}")
# Server type
server_type = st.selectbox(
"Server Type",
["http", "websocket"],
index=0 if server_config.get("type", "http") == "http" else 1,
key=f"type_{server_name}"
)
# Server URL
url = st.text_input(
"Server URL",
value=server_config.get("url", ""),
key=f"url_{server_name}"
)
# Headers
st.write("Headers:")
headers = server_config.get("headers", {})
headers_container = st.container()
with headers_container:
# Display existing headers
for header_key, header_value in list(headers.items()):
col1, col2, col3 = st.columns([3, 6, 1])
with col1:
st.text(header_key)
with col2:
# Check if this is a reference to an input
if isinstance(header_value, str) and "${input:" in header_value:
st.text(header_value)
else:
st.text("********" if "token" in header_key.lower() or "auth" in header_key.lower() else header_value)
with col3:
if st.button("🗑️", key=f"delete_header_{server_name}_{header_key}"):
del headers[header_key]
# Add new header
st.write("Add Header:")
header_col1, header_col2 = st.columns([1, 1])
new_header_key = header_col1.text_input("Key", key=f"new_header_key_{server_name}")
new_header_value = header_col2.text_input("Value", key=f"new_header_value_{server_name}")
if st.button("Add Header", key=f"add_header_{server_name}"):
if new_header_key and new_header_value:
headers[new_header_key] = new_header_value
st.experimental_rerun()
# Construct updated server config
updated_server_config = {
"type": server_type,
"url": url,
"headers": headers
}
return updated_server_config
def render_input_config(input_config: Dict[str, Any], index: int) -> Dict[str, Any]:
"""Render and edit configuration for an input field."""
st.subheader(f"Input: {input_config.get('id', '')}")
col1, col2 = st.columns([1, 1])
input_type = col1.selectbox(
"Type",
["promptString"],
index=0,
key=f"input_type_{index}"
)
input_id = col2.text_input(
"ID",
value=input_config.get("id", ""),
key=f"input_id_{index}"
)
description = st.text_input(
"Description",
value=input_config.get("description", ""),
key=f"input_desc_{index}"
)
is_password = st.checkbox(
"Password Field",
value=input_config.get("password", False),
key=f"input_password_{index}"
)
# Construct updated input config
updated_input_config = {
"type": input_type,
"id": input_id,
"description": description,
"password": is_password
}
return updated_input_config
def main():
st.title("MCP Servers Configuration")
@ -214,181 +23,318 @@ def main():
MCP servers are registered as toolgroups with the ID format `mcp::{server_name}`.
""")
# Load existing configuration
config = load_config()
# MCP Server Configuration
st.header("MCP Server Configuration")
# Tabs for different sections
tab1, tab2, tab3, tab4 = st.tabs(["Servers", "Inputs", "Input Values", "JSON Config"])
# Servers Tab
with tab1:
st.header("Configured Servers")
# Create a form for MCP configuration
with st.form("mcp_server_form"):
# Server name
server_name = st.text_input(
"Server Name",
value="github",
help="A unique name for this MCP server. Will be used in the toolgroup ID as mcp::{name}."
)
# List existing servers
servers = config.get("servers", {})
if not servers:
st.info("No MCP servers configured yet. Add a new server below.")
# MCP URL
mcp_url = st.text_input(
"MCP URL",
value="https://api.githubcopilot.com/mcp/",
help="The URL of the MCP server."
)
# Server selection or creation
server_options = list(servers.keys()) + ["+ Add New Server"]
selected_server = st.selectbox("Select Server", server_options)
# Get the current value from session state
api_token = st.session_state.get(f"mcp_token_{server_name}", "")
if selected_server == "+ Add New Server":
new_server_name = st.text_input("New Server Name")
if new_server_name and st.button("Create Server"):
if new_server_name in servers:
st.error(f"Server '{new_server_name}' already exists.")
else:
servers[new_server_name] = {
"type": "http",
"url": "",
"headers": {}
}
st.experimental_rerun()
elif selected_server in servers:
# Edit existing server
updated_config = render_server_config(selected_server, servers[selected_server], config)
col1, col2 = st.columns([1, 1])
if col1.button("Update Server", key=f"update_{selected_server}"):
servers[selected_server] = updated_config
save_config(config)
if col2.button("Delete Server", key=f"delete_{selected_server}"):
del servers[selected_server]
save_config(config)
st.experimental_rerun()
# Inputs Tab
with tab2:
st.header("Input Configurations")
# Input field for API Bearer Token
mcp_token = st.text_input(
"API Bearer Token",
value=api_token,
type="password",
help="Enter your API Bearer Token. For GitHub Copilot, this should be a GitHub Personal Access Token with Copilot scope."
)
inputs = config.get("inputs", [])
if not inputs:
st.info("No input configurations defined yet. Add a new input below.")
# Input selection or creation
input_options = [f"{i.get('id', f'Input {idx}')} ({i.get('type', 'promptString')})" for idx, i in enumerate(inputs)]
input_options.append("+ Add New Input")
# Submit button
submit_button = st.form_submit_button("Save Configuration")
selected_input_option = st.selectbox("Select Input", input_options)
if selected_input_option == "+ Add New Input":
if st.button("Create Input"):
inputs.append({
"type": "promptString",
"id": f"input_{len(inputs)}",
"description": "",
"password": False
})
save_config(config)
st.experimental_rerun()
else:
# Edit existing input
selected_idx = input_options.index(selected_input_option)
if selected_idx < len(inputs):
updated_input = render_input_config(inputs[selected_idx], selected_idx)
if submit_button:
if not server_name:
st.error("Server name is required.")
elif not mcp_url:
st.error("MCP URL is required.")
else:
# Store the token in session state
st.session_state[f"mcp_token_{server_name}"] = mcp_token
col1, col2 = st.columns([1, 1])
if col1.button("Update Input", key=f"update_input_{selected_idx}"):
inputs[selected_idx] = updated_input
save_config(config)
if col2.button("Delete Input", key=f"delete_input_{selected_idx}"):
inputs.pop(selected_idx)
save_config(config)
st.experimental_rerun()
try:
# Register the MCP server as a toolgroup
toolgroup_id = f"mcp::{server_name}"
try:
llama_stack_api.client.toolgroups.register(
toolgroup_id=toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=McpEndpoint(uri=mcp_url),
timeout=10.0, # Set a reasonable timeout
)
except Exception as e:
if "timeout" in str(e).lower():
st.warning(f"Registration timed out, but configuration will still be saved. Error: {str(e)}")
else:
raise
# Update provider data with MCP headers
# Check if provider_data attribute exists
if not hasattr(llama_stack_api, "provider_data"):
llama_stack_api.provider_data = {}
provider_data = llama_stack_api.provider_data.copy()
if "mcp_headers" not in provider_data:
provider_data["mcp_headers"] = {}
# Add MCP headers
if mcp_token:
# Clean the token (remove 'Bearer ' prefix if present)
clean_token = mcp_token
if clean_token.lower().startswith("bearer "):
clean_token = clean_token[7:]
# Set the headers for this MCP server
# The key needs to be the exact URL used in the MCP endpoint
provider_data["mcp_headers"][mcp_url] = {
"Authorization": f"Bearer {clean_token}"
}
# Debug information
st.info(f"Set authentication headers for {mcp_url}: Bearer {clean_token[:4]}...")
# Also set the token directly in the provider_data using different formats
# This increases the chance that one of them will work
provider_data[f"{server_name}_api_key"] = clean_token
provider_data[f"{server_name}_token"] = clean_token
provider_data[f"{server_name}_mcp_token"] = clean_token
# Display the current provider_data for debugging
st.write("Current provider_data:")
# Mask the token for display
masked_token = ""
if clean_token:
if len(clean_token) > 8:
# Show first 2 and last 2 characters, mask the middle with asterisks
masked_token = f"Bearer {clean_token[:2]}{'*' * 6}{clean_token[-2:]}"
else:
# For short tokens, just show first 2 chars and asterisks
masked_token = f"Bearer {clean_token[:2]}{'*' * 4}"
# Create a sanitized version of mcp_headers for display
sanitized_headers = {}
for url, headers in provider_data.get("mcp_headers", {}).items():
sanitized_headers[url] = {}
for header_key, header_value in headers.items():
if header_key.lower() == "authorization" and isinstance(header_value, str):
if header_value.lower().startswith("bearer "):
token = header_value[7:]
if len(token) > 8:
sanitized_headers[url][header_key] = f"Bearer {token[:2]}{'*' * 6}{token[-2:]}"
else:
sanitized_headers[url][header_key] = f"Bearer {token[:2]}{'*' * 4}"
else:
sanitized_headers[url][header_key] = f"{header_value[:2]}{'*' * 6}"
else:
sanitized_headers[url][header_key] = header_value
st.json({
"mcp_headers": sanitized_headers,
f"{server_name}_api_key": masked_token
})
# Add a note about authentication
st.warning("""
**Important Authentication Notes:**
1. If you encounter authentication errors when using this MCP server,
try restarting the UI application to ensure the authentication headers are properly applied.
2. For GitHub Copilot, make sure you're using a valid GitHub Personal Access Token with the 'copilot' scope.
3. When using the MCP server in code, you may need to explicitly pass the authentication headers:
```python
import json
from llama_stack_client import Agent
agent = Agent(
model="llama-3-70b-instruct",
tools=["mcp::github"],
extra_headers={
"X-LlamaStack-Provider-Data": json.dumps({
"mcp_headers": {
"https://api.githubcopilot.com/mcp/": {
"Authorization": "Bearer YOUR_TOKEN_HERE"
}
}
})
}
)
```
""")
# Update the client with the new provider data
if hasattr(llama_stack_api, "update_provider_data_dict"):
llama_stack_api.update_provider_data_dict(provider_data)
else:
# Fallback implementation if method doesn't exist
llama_stack_api.provider_data = provider_data
# Reinitialize the client with updated provider data
llama_stack_api.client = LlamaStackClient(
base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
provider_data=provider_data,
)
st.success(f"MCP server '{server_name}' configured successfully!")
except Exception as e:
st.error(f"Failed to configure MCP server '{server_name}': {str(e)}")
# Input Values Tab
with tab3:
st.header("Input Values")
inputs = config.get("inputs", [])
if not inputs:
st.info("No input configurations defined yet. Add inputs in the Inputs tab.")
else:
st.write("Enter values for the configured inputs:")
for input_config in inputs:
input_id = input_config.get("id", "")
description = input_config.get("description", "")
is_password = input_config.get("password", False)
# Get value from session state if available
current_value = st.session_state.get(f"input_value_{input_id}", "")
# Input field for value
value = st.text_input(
f"{description} ({input_id})",
value=current_value,
type="password" if is_password else "default",
key=f"input_value_{input_id}"
)
if st.button("Save Input Values"):
# Values are automatically saved to session state
# Register MCP servers to apply the new values
register_mcp_servers(config)
st.success("Input values saved successfully!")
# JSON Config Tab
with tab4:
st.header("Raw JSON Configuration")
# Display and edit raw JSON
json_str = json.dumps(config, indent=2)
edited_json = st.text_area("Edit JSON Configuration", json_str, height=400)
if st.button("Update from JSON"):
try:
updated_config = json.loads(edited_json)
save_config(updated_config)
st.experimental_rerun()
except json.JSONDecodeError as e:
st.error(f"Invalid JSON: {str(e)}")
# Example configuration and usage
with st.expander("Example Configuration and Usage"):
# Usage example
with st.expander("Usage Example"):
st.markdown("""
### Example Configuration
### Using MCP Servers in Code
```json
{
"servers": {
"github": {
"type": "http",
"url": "https://api.githubcopilot.com/mcp/",
"headers": {
"Authorization": "Bearer ${input:github_mcp_pat}"
}
}
},
"inputs": [
{
"type": "promptString",
"id": "github_mcp_pat",
"description": "GitHub Personal Access Token",
"password": true
}
]
}
```
### Usage in Code
Once registered, you can use the MCP server in your code:
Once configured, you can use the MCP server in your code:
```python
from llama_stack_client import Agent
agent = Agent(
model="llama-3-70b-instruct",
tools=["mcp::github"], # Use the registered MCP server
tools=["mcp::your_server_name"], # Use the registered MCP server
)
agent.create_turn("Use GitHub Copilot to help me write a function")
agent.create_turn("Use the MCP server to help me with a task")
```
### With Explicit Authentication Headers
If you encounter authentication issues, you can explicitly pass the authentication headers:
```python
import json
from llama_stack_client import Agent
agent = Agent(
model="llama-3-70b-instruct",
tools=["mcp::github"],
extra_headers={
"X-LlamaStack-Provider-Data": json.dumps({
"mcp_headers": {
"https://api.githubcopilot.com/mcp/": {
"Authorization": "Bearer YOUR_TOKEN_HERE"
}
}
})
}
)
```
""")
# Display registered toolgroups
st.header("Registered MCP Toolgroups")
try:
toolgroups = llama_stack_api.client.toolgroups.list(timeout=5.0) # Set a reasonable timeout
# Debug the structure of the toolgroups
if toolgroups:
# Check the first toolgroup to determine its structure
first_toolgroup = toolgroups[0] if toolgroups else None
if first_toolgroup:
# Get the attribute names
attr_names = dir(first_toolgroup)
# The ToolGroup class has an 'identifier' attribute as per the API definition
id_attr = 'identifier'
# Filter MCP toolgroups based on the identifier attribute
mcp_toolgroups = []
for tg in toolgroups:
if hasattr(tg, 'identifier') and isinstance(tg.identifier, str) and tg.identifier.startswith("mcp::"):
mcp_toolgroups.append(tg)
if mcp_toolgroups:
# Display MCP servers with unregister buttons
st.write("Click the button next to a server to unregister it:")
for tg in mcp_toolgroups:
col1, col2, col3, col4 = st.columns([3, 2, 3, 1])
with col1:
st.write(f"**{tg.identifier}**")
with col2:
st.write(tg.provider_id)
with col3:
if hasattr(tg, 'mcp_endpoint') and tg.mcp_endpoint:
st.write(tg.mcp_endpoint.uri)
else:
st.write("N/A")
with col4:
# Extract server name from identifier (remove "mcp::" prefix)
server_name = tg.identifier.replace("mcp::", "")
if st.button("Unregister", key=f"unregister_{server_name}"):
try:
# Call the unregister API
llama_stack_api.client.toolgroups.unregister(
toolgroup_id=tg.identifier,
timeout=5.0
)
# Also clean up provider data
if hasattr(llama_stack_api, "provider_data"):
provider_data = llama_stack_api.provider_data.copy()
# Remove MCP headers for this server if they exist
if "mcp_headers" in provider_data and hasattr(tg, 'mcp_endpoint') and tg.mcp_endpoint:
if tg.mcp_endpoint.uri in provider_data["mcp_headers"]:
del provider_data["mcp_headers"][tg.mcp_endpoint.uri]
# Remove server-specific tokens
keys_to_remove = [
f"{server_name}_api_key",
f"{server_name}_token",
f"{server_name}_mcp_token"
]
for key in keys_to_remove:
if key in provider_data:
del provider_data[key]
# Update the client with the modified provider data
if hasattr(llama_stack_api, "update_provider_data_dict"):
llama_stack_api.update_provider_data_dict(provider_data)
else:
# Fallback implementation
llama_stack_api.provider_data = provider_data
llama_stack_api.client = LlamaStackClient(
base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
provider_data=provider_data,
)
st.success(f"Successfully unregistered {tg.identifier}")
st.rerun() # Refresh the page to update the list
except Exception as e:
st.error(f"Failed to unregister {tg.identifier}: {str(e)}")
else:
st.info("No MCP toolgroups registered yet.")
else:
st.info("No toolgroups found.")
else:
st.info("No toolgroups returned from the API.")
except Exception as e:
if "timeout" in str(e).lower():
st.warning("Listing toolgroups timed out. The server might be busy or unreachable.")
else:
st.error(f"Failed to list toolgroups: {str(e)}")
if __name__ == "__main__":
main()

View file

@ -80,9 +80,14 @@ def tool_chat_page():
total_tools = 0
for toolgroup_id in toolgroup_selection:
tools = client.tools.list(toolgroup_id=toolgroup_id)
grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
total_tools += len(tools)
try:
# Add a timeout of 5 seconds to prevent UI freezing
tools = client.tools.list(toolgroup_id=toolgroup_id, timeout=5.0)
grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
total_tools += len(tools)
except Exception as e:
st.warning(f"Failed to list tools for {toolgroup_id}: {str(e)}")
grouped_tools[toolgroup_id] = []
st.markdown(f"Active Tools: 🛠 {total_tools}")
@ -126,25 +131,56 @@ def tool_chat_page():
@st.cache_resource
def create_agent():
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
return ReActAgent(
client=client,
model=model,
tools=toolgroup_selection,
response_format={
"type": "json_schema",
"json_schema": ReActOutput.model_json_schema(),
},
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
else:
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
try:
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
return ReActAgent(
client=client,
model=model,
tools=toolgroup_selection,
response_format={
"type": "json_schema",
"json_schema": ReActOutput.model_json_schema(),
},
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
else:
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
except Exception as e:
# Log the error
st.error(f"Failed to create agent: {str(e)}")
# Create a fallback agent without tools
st.warning(
"Creating a fallback agent without tools due to an error. "
"Some functionality may be limited. Try refreshing the page or selecting different tools."
)
# Return a basic agent without tools
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
return ReActAgent(
client=client,
model=model,
tools=[], # No tools
response_format={
"type": "json_schema",
"json_schema": ReActOutput.model_json_schema(),
},
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
else:
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=[], # No tools
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
st.session_state.agent_type = agent_type