This commit is contained in:
Kai Wu 2025-08-03 14:01:27 -07:00
parent 4e19f15bca
commit dcc47c2008
6 changed files with 351 additions and 8927 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load diff

View file

@ -39,7 +39,7 @@ spec:
image: vllm/vllm-openai:latest
command: ["/bin/sh", "-c"]
args:
- "vllm serve ${INFERENCE_MODEL} --enforce-eager --max-model-len 8192 --gpu-memory-utilization 0.7 --enable-auto-tool-choice --tool-call-parser llama3_json --max-num-seqs 4 --port 8001"
- "vllm serve ${INFERENCE_MODEL} --enforce-eager --max-model-len 100000 --gpu-memory-utilization 0.9 --enable-auto-tool-choice --tool-call-parser llama3_json --max-num-seqs 2 --port 8001"
env:
- name: INFERENCE_MODEL
value: "${INFERENCE_MODEL}"

View file

@ -7,202 +7,11 @@
import json
import os
import streamlit as st
from typing import Dict, List, Optional, Any
from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack_client import LlamaStackClient
from llama_stack_client.types.toolgroup_register_params import McpEndpoint
# Constants
CONFIG_DIR = os.path.expanduser("~/.llama_stack/mcp_servers")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
# Ensure config directory exists
os.makedirs(CONFIG_DIR, exist_ok=True)
def load_config() -> Dict[str, Any]:
"""Load MCP server configurations from file."""
if os.path.exists(CONFIG_FILE):
try:
with open(CONFIG_FILE, "r") as f:
return json.load(f)
except json.JSONDecodeError:
st.error("Error loading MCP server configuration file. Using default configuration.")
# Default empty configuration
return {
"servers": {},
"inputs": []
}
def register_mcp_servers(config: Dict[str, Any]) -> None:
"""Register MCP servers as toolgroups with the LlamaStackClient."""
servers = config.get("servers", {})
inputs = config.get("inputs", [])
# Process inputs to resolve values
input_values = {}
for input_config in inputs:
input_id = input_config.get("id")
if input_id:
# Get value from session state if available
input_values[input_id] = st.session_state.get(f"input_value_{input_id}", "")
# Update provider data with MCP headers
mcp_headers = {}
# Register each server as a toolgroup
for server_name, server_config in servers.items():
url = server_config.get("url", "")
if not url:
continue
try:
# Register the MCP server as a toolgroup
toolgroup_id = f"mcp::{server_name}"
# Register the toolgroup
llama_stack_api.client.toolgroups.register(
toolgroup_id=toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=url,
)
# Process headers for this server
headers = server_config.get("headers", {})
processed_headers = {}
for header_key, header_value in headers.items():
# Process input references in header values
if isinstance(header_value, str) and "${input:" in header_value:
# Extract input ID from ${input:input_id} format
input_id = header_value.split("${input:")[1].split("}")[0]
if input_id in input_values:
processed_headers[header_key] = input_values[input_id]
else:
processed_headers[header_key] = header_value
# Add headers to mcp_headers if there are any
if processed_headers:
mcp_headers[url] = processed_headers
except Exception as e:
st.error(f"Failed to register MCP server '{server_name}': {str(e)}")
# Update provider data with MCP headers if there are any
if mcp_headers:
provider_data = llama_stack_api.provider_data.copy()
provider_data["mcp_headers"] = mcp_headers
llama_stack_api.update_provider_data_dict(provider_data)
def save_config(config: Dict[str, Any]) -> None:
"""Save MCP server configurations to file."""
with open(CONFIG_FILE, "w") as f:
json.dump(config, f, indent=2)
# Register MCP servers as toolgroups
register_mcp_servers(config)
st.success("MCP server configuration saved successfully!")
def render_server_config(server_name: str, server_config: Dict[str, Any], config: Dict[str, Any]) -> Dict[str, Any]:
"""Render and edit configuration for a specific MCP server."""
st.subheader(f"Server: {server_name}")
# Server type
server_type = st.selectbox(
"Server Type",
["http", "websocket"],
index=0 if server_config.get("type", "http") == "http" else 1,
key=f"type_{server_name}"
)
# Server URL
url = st.text_input(
"Server URL",
value=server_config.get("url", ""),
key=f"url_{server_name}"
)
# Headers
st.write("Headers:")
headers = server_config.get("headers", {})
headers_container = st.container()
with headers_container:
# Display existing headers
for header_key, header_value in list(headers.items()):
col1, col2, col3 = st.columns([3, 6, 1])
with col1:
st.text(header_key)
with col2:
# Check if this is a reference to an input
if isinstance(header_value, str) and "${input:" in header_value:
st.text(header_value)
else:
st.text("********" if "token" in header_key.lower() or "auth" in header_key.lower() else header_value)
with col3:
if st.button("🗑️", key=f"delete_header_{server_name}_{header_key}"):
del headers[header_key]
# Add new header
st.write("Add Header:")
header_col1, header_col2 = st.columns([1, 1])
new_header_key = header_col1.text_input("Key", key=f"new_header_key_{server_name}")
new_header_value = header_col2.text_input("Value", key=f"new_header_value_{server_name}")
if st.button("Add Header", key=f"add_header_{server_name}"):
if new_header_key and new_header_value:
headers[new_header_key] = new_header_value
st.experimental_rerun()
# Construct updated server config
updated_server_config = {
"type": server_type,
"url": url,
"headers": headers
}
return updated_server_config
def render_input_config(input_config: Dict[str, Any], index: int) -> Dict[str, Any]:
"""Render and edit configuration for an input field."""
st.subheader(f"Input: {input_config.get('id', '')}")
col1, col2 = st.columns([1, 1])
input_type = col1.selectbox(
"Type",
["promptString"],
index=0,
key=f"input_type_{index}"
)
input_id = col2.text_input(
"ID",
value=input_config.get("id", ""),
key=f"input_id_{index}"
)
description = st.text_input(
"Description",
value=input_config.get("description", ""),
key=f"input_desc_{index}"
)
is_password = st.checkbox(
"Password Field",
value=input_config.get("password", False),
key=f"input_password_{index}"
)
# Construct updated input config
updated_input_config = {
"type": input_type,
"id": input_id,
"description": description,
"password": is_password
}
return updated_input_config
def main():
st.title("MCP Servers Configuration")
@ -214,181 +23,318 @@ def main():
MCP servers are registered as toolgroups with the ID format `mcp::{server_name}`.
""")
# Load existing configuration
config = load_config()
# MCP Server Configuration
st.header("MCP Server Configuration")
# Tabs for different sections
tab1, tab2, tab3, tab4 = st.tabs(["Servers", "Inputs", "Input Values", "JSON Config"])
# Servers Tab
with tab1:
st.header("Configured Servers")
# Create a form for MCP configuration
with st.form("mcp_server_form"):
# Server name
server_name = st.text_input(
"Server Name",
value="github",
help="A unique name for this MCP server. Will be used in the toolgroup ID as mcp::{name}."
)
# List existing servers
servers = config.get("servers", {})
if not servers:
st.info("No MCP servers configured yet. Add a new server below.")
# MCP URL
mcp_url = st.text_input(
"MCP URL",
value="https://api.githubcopilot.com/mcp/",
help="The URL of the MCP server."
)
# Server selection or creation
server_options = list(servers.keys()) + ["+ Add New Server"]
selected_server = st.selectbox("Select Server", server_options)
# Get the current value from session state
api_token = st.session_state.get(f"mcp_token_{server_name}", "")
if selected_server == "+ Add New Server":
new_server_name = st.text_input("New Server Name")
if new_server_name and st.button("Create Server"):
if new_server_name in servers:
st.error(f"Server '{new_server_name}' already exists.")
else:
servers[new_server_name] = {
"type": "http",
"url": "",
"headers": {}
}
st.experimental_rerun()
elif selected_server in servers:
# Edit existing server
updated_config = render_server_config(selected_server, servers[selected_server], config)
col1, col2 = st.columns([1, 1])
if col1.button("Update Server", key=f"update_{selected_server}"):
servers[selected_server] = updated_config
save_config(config)
if col2.button("Delete Server", key=f"delete_{selected_server}"):
del servers[selected_server]
save_config(config)
st.experimental_rerun()
# Inputs Tab
with tab2:
st.header("Input Configurations")
# Input field for API Bearer Token
mcp_token = st.text_input(
"API Bearer Token",
value=api_token,
type="password",
help="Enter your API Bearer Token. For GitHub Copilot, this should be a GitHub Personal Access Token with Copilot scope."
)
inputs = config.get("inputs", [])
if not inputs:
st.info("No input configurations defined yet. Add a new input below.")
# Input selection or creation
input_options = [f"{i.get('id', f'Input {idx}')} ({i.get('type', 'promptString')})" for idx, i in enumerate(inputs)]
input_options.append("+ Add New Input")
# Submit button
submit_button = st.form_submit_button("Save Configuration")
selected_input_option = st.selectbox("Select Input", input_options)
if selected_input_option == "+ Add New Input":
if st.button("Create Input"):
inputs.append({
"type": "promptString",
"id": f"input_{len(inputs)}",
"description": "",
"password": False
})
save_config(config)
st.experimental_rerun()
else:
# Edit existing input
selected_idx = input_options.index(selected_input_option)
if selected_idx < len(inputs):
updated_input = render_input_config(inputs[selected_idx], selected_idx)
if submit_button:
if not server_name:
st.error("Server name is required.")
elif not mcp_url:
st.error("MCP URL is required.")
else:
# Store the token in session state
st.session_state[f"mcp_token_{server_name}"] = mcp_token
col1, col2 = st.columns([1, 1])
if col1.button("Update Input", key=f"update_input_{selected_idx}"):
inputs[selected_idx] = updated_input
save_config(config)
if col2.button("Delete Input", key=f"delete_input_{selected_idx}"):
inputs.pop(selected_idx)
save_config(config)
st.experimental_rerun()
try:
# Register the MCP server as a toolgroup
toolgroup_id = f"mcp::{server_name}"
try:
llama_stack_api.client.toolgroups.register(
toolgroup_id=toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=McpEndpoint(uri=mcp_url),
timeout=10.0, # Set a reasonable timeout
)
except Exception as e:
if "timeout" in str(e).lower():
st.warning(f"Registration timed out, but configuration will still be saved. Error: {str(e)}")
else:
raise
# Update provider data with MCP headers
# Check if provider_data attribute exists
if not hasattr(llama_stack_api, "provider_data"):
llama_stack_api.provider_data = {}
provider_data = llama_stack_api.provider_data.copy()
if "mcp_headers" not in provider_data:
provider_data["mcp_headers"] = {}
# Add MCP headers
if mcp_token:
# Clean the token (remove 'Bearer ' prefix if present)
clean_token = mcp_token
if clean_token.lower().startswith("bearer "):
clean_token = clean_token[7:]
# Set the headers for this MCP server
# The key needs to be the exact URL used in the MCP endpoint
provider_data["mcp_headers"][mcp_url] = {
"Authorization": f"Bearer {clean_token}"
}
# Debug information
st.info(f"Set authentication headers for {mcp_url}: Bearer {clean_token[:4]}...")
# Also set the token directly in the provider_data using different formats
# This increases the chance that one of them will work
provider_data[f"{server_name}_api_key"] = clean_token
provider_data[f"{server_name}_token"] = clean_token
provider_data[f"{server_name}_mcp_token"] = clean_token
# Display the current provider_data for debugging
st.write("Current provider_data:")
# Mask the token for display
masked_token = ""
if clean_token:
if len(clean_token) > 8:
# Show first 2 and last 2 characters, mask the middle with asterisks
masked_token = f"Bearer {clean_token[:2]}{'*' * 6}{clean_token[-2:]}"
else:
# For short tokens, just show first 2 chars and asterisks
masked_token = f"Bearer {clean_token[:2]}{'*' * 4}"
# Create a sanitized version of mcp_headers for display
sanitized_headers = {}
for url, headers in provider_data.get("mcp_headers", {}).items():
sanitized_headers[url] = {}
for header_key, header_value in headers.items():
if header_key.lower() == "authorization" and isinstance(header_value, str):
if header_value.lower().startswith("bearer "):
token = header_value[7:]
if len(token) > 8:
sanitized_headers[url][header_key] = f"Bearer {token[:2]}{'*' * 6}{token[-2:]}"
else:
sanitized_headers[url][header_key] = f"Bearer {token[:2]}{'*' * 4}"
else:
sanitized_headers[url][header_key] = f"{header_value[:2]}{'*' * 6}"
else:
sanitized_headers[url][header_key] = header_value
st.json({
"mcp_headers": sanitized_headers,
f"{server_name}_api_key": masked_token
})
# Add a note about authentication
st.warning("""
**Important Authentication Notes:**
1. If you encounter authentication errors when using this MCP server,
try restarting the UI application to ensure the authentication headers are properly applied.
2. For GitHub Copilot, make sure you're using a valid GitHub Personal Access Token with the 'copilot' scope.
3. When using the MCP server in code, you may need to explicitly pass the authentication headers:
```python
import json
from llama_stack_client import Agent
agent = Agent(
model="llama-3-70b-instruct",
tools=["mcp::github"],
extra_headers={
"X-LlamaStack-Provider-Data": json.dumps({
"mcp_headers": {
"https://api.githubcopilot.com/mcp/": {
"Authorization": "Bearer YOUR_TOKEN_HERE"
}
}
})
}
)
```
""")
# Update the client with the new provider data
if hasattr(llama_stack_api, "update_provider_data_dict"):
llama_stack_api.update_provider_data_dict(provider_data)
else:
# Fallback implementation if method doesn't exist
llama_stack_api.provider_data = provider_data
# Reinitialize the client with updated provider data
llama_stack_api.client = LlamaStackClient(
base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
provider_data=provider_data,
)
st.success(f"MCP server '{server_name}' configured successfully!")
except Exception as e:
st.error(f"Failed to configure MCP server '{server_name}': {str(e)}")
# Input Values Tab
with tab3:
st.header("Input Values")
inputs = config.get("inputs", [])
if not inputs:
st.info("No input configurations defined yet. Add inputs in the Inputs tab.")
else:
st.write("Enter values for the configured inputs:")
for input_config in inputs:
input_id = input_config.get("id", "")
description = input_config.get("description", "")
is_password = input_config.get("password", False)
# Get value from session state if available
current_value = st.session_state.get(f"input_value_{input_id}", "")
# Input field for value
value = st.text_input(
f"{description} ({input_id})",
value=current_value,
type="password" if is_password else "default",
key=f"input_value_{input_id}"
)
if st.button("Save Input Values"):
# Values are automatically saved to session state
# Register MCP servers to apply the new values
register_mcp_servers(config)
st.success("Input values saved successfully!")
# JSON Config Tab
with tab4:
st.header("Raw JSON Configuration")
# Display and edit raw JSON
json_str = json.dumps(config, indent=2)
edited_json = st.text_area("Edit JSON Configuration", json_str, height=400)
if st.button("Update from JSON"):
try:
updated_config = json.loads(edited_json)
save_config(updated_config)
st.experimental_rerun()
except json.JSONDecodeError as e:
st.error(f"Invalid JSON: {str(e)}")
# Example configuration and usage
with st.expander("Example Configuration and Usage"):
# Usage example
with st.expander("Usage Example"):
st.markdown("""
### Example Configuration
### Using MCP Servers in Code
```json
{
"servers": {
"github": {
"type": "http",
"url": "https://api.githubcopilot.com/mcp/",
"headers": {
"Authorization": "Bearer ${input:github_mcp_pat}"
}
}
},
"inputs": [
{
"type": "promptString",
"id": "github_mcp_pat",
"description": "GitHub Personal Access Token",
"password": true
}
]
}
```
### Usage in Code
Once registered, you can use the MCP server in your code:
Once configured, you can use the MCP server in your code:
```python
from llama_stack_client import Agent
agent = Agent(
model="llama-3-70b-instruct",
tools=["mcp::github"], # Use the registered MCP server
tools=["mcp::your_server_name"], # Use the registered MCP server
)
agent.create_turn("Use GitHub Copilot to help me write a function")
agent.create_turn("Use the MCP server to help me with a task")
```
### With Explicit Authentication Headers
If you encounter authentication issues, you can explicitly pass the authentication headers:
```python
import json
from llama_stack_client import Agent
agent = Agent(
model="llama-3-70b-instruct",
tools=["mcp::github"],
extra_headers={
"X-LlamaStack-Provider-Data": json.dumps({
"mcp_headers": {
"https://api.githubcopilot.com/mcp/": {
"Authorization": "Bearer YOUR_TOKEN_HERE"
}
}
})
}
)
```
""")
# Display registered toolgroups
st.header("Registered MCP Toolgroups")
try:
toolgroups = llama_stack_api.client.toolgroups.list(timeout=5.0) # Set a reasonable timeout
# Debug the structure of the toolgroups
if toolgroups:
# Check the first toolgroup to determine its structure
first_toolgroup = toolgroups[0] if toolgroups else None
if first_toolgroup:
# Get the attribute names
attr_names = dir(first_toolgroup)
# The ToolGroup class has an 'identifier' attribute as per the API definition
id_attr = 'identifier'
# Filter MCP toolgroups based on the identifier attribute
mcp_toolgroups = []
for tg in toolgroups:
if hasattr(tg, 'identifier') and isinstance(tg.identifier, str) and tg.identifier.startswith("mcp::"):
mcp_toolgroups.append(tg)
if mcp_toolgroups:
# Display MCP servers with unregister buttons
st.write("Click the button next to a server to unregister it:")
for tg in mcp_toolgroups:
col1, col2, col3, col4 = st.columns([3, 2, 3, 1])
with col1:
st.write(f"**{tg.identifier}**")
with col2:
st.write(tg.provider_id)
with col3:
if hasattr(tg, 'mcp_endpoint') and tg.mcp_endpoint:
st.write(tg.mcp_endpoint.uri)
else:
st.write("N/A")
with col4:
# Extract server name from identifier (remove "mcp::" prefix)
server_name = tg.identifier.replace("mcp::", "")
if st.button("Unregister", key=f"unregister_{server_name}"):
try:
# Call the unregister API
llama_stack_api.client.toolgroups.unregister(
toolgroup_id=tg.identifier,
timeout=5.0
)
# Also clean up provider data
if hasattr(llama_stack_api, "provider_data"):
provider_data = llama_stack_api.provider_data.copy()
# Remove MCP headers for this server if they exist
if "mcp_headers" in provider_data and hasattr(tg, 'mcp_endpoint') and tg.mcp_endpoint:
if tg.mcp_endpoint.uri in provider_data["mcp_headers"]:
del provider_data["mcp_headers"][tg.mcp_endpoint.uri]
# Remove server-specific tokens
keys_to_remove = [
f"{server_name}_api_key",
f"{server_name}_token",
f"{server_name}_mcp_token"
]
for key in keys_to_remove:
if key in provider_data:
del provider_data[key]
# Update the client with the modified provider data
if hasattr(llama_stack_api, "update_provider_data_dict"):
llama_stack_api.update_provider_data_dict(provider_data)
else:
# Fallback implementation
llama_stack_api.provider_data = provider_data
llama_stack_api.client = LlamaStackClient(
base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
provider_data=provider_data,
)
st.success(f"Successfully unregistered {tg.identifier}")
st.rerun() # Refresh the page to update the list
except Exception as e:
st.error(f"Failed to unregister {tg.identifier}: {str(e)}")
else:
st.info("No MCP toolgroups registered yet.")
else:
st.info("No toolgroups found.")
else:
st.info("No toolgroups returned from the API.")
except Exception as e:
if "timeout" in str(e).lower():
st.warning("Listing toolgroups timed out. The server might be busy or unreachable.")
else:
st.error(f"Failed to list toolgroups: {str(e)}")
if __name__ == "__main__":
main()

View file

@ -80,9 +80,14 @@ def tool_chat_page():
total_tools = 0
for toolgroup_id in toolgroup_selection:
tools = client.tools.list(toolgroup_id=toolgroup_id)
grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
total_tools += len(tools)
try:
# Add a timeout of 5 seconds to prevent UI freezing
tools = client.tools.list(toolgroup_id=toolgroup_id, timeout=5.0)
grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
total_tools += len(tools)
except Exception as e:
st.warning(f"Failed to list tools for {toolgroup_id}: {str(e)}")
grouped_tools[toolgroup_id] = []
st.markdown(f"Active Tools: 🛠 {total_tools}")
@ -126,25 +131,56 @@ def tool_chat_page():
@st.cache_resource
def create_agent():
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
return ReActAgent(
client=client,
model=model,
tools=toolgroup_selection,
response_format={
"type": "json_schema",
"json_schema": ReActOutput.model_json_schema(),
},
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
else:
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
try:
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
return ReActAgent(
client=client,
model=model,
tools=toolgroup_selection,
response_format={
"type": "json_schema",
"json_schema": ReActOutput.model_json_schema(),
},
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
else:
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
except Exception as e:
# Log the error
st.error(f"Failed to create agent: {str(e)}")
# Create a fallback agent without tools
st.warning(
"Creating a fallback agent without tools due to an error. "
"Some functionality may be limited. Try refreshing the page or selecting different tools."
)
# Return a basic agent without tools
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
return ReActAgent(
client=client,
model=model,
tools=[], # No tools
response_format={
"type": "json_schema",
"json_schema": ReActOutput.model_json_schema(),
},
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
else:
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=[], # No tools
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
st.session_state.agent_type = agent_type