From 4e19f15bca9c87e4a597a121136392ab8dde9a3f Mon Sep 17 00:00:00 2001 From: Kai Wu Date: Sat, 2 Aug 2025 15:54:52 -0700 Subject: [PATCH] add mcp --- llama_stack/distribution/ui/app.py | 8 +- llama_stack/distribution/ui/modules/api.py | 82 +++- .../ui/page/distribution/mcp_servers.py | 394 ++++++++++++++++++ .../ui/page/distribution/providers.py | 4 +- 4 files changed, 478 insertions(+), 10 deletions(-) create mode 100644 llama_stack/distribution/ui/page/distribution/mcp_servers.py diff --git a/llama_stack/distribution/ui/app.py b/llama_stack/distribution/ui/app.py index 441f65d20..7a6b8b887 100644 --- a/llama_stack/distribution/ui/app.py +++ b/llama_stack/distribution/ui/app.py @@ -34,6 +34,12 @@ def main(): icon="🔍", default=False, ) + mcp_servers_page = st.Page( + "page/distribution/mcp_servers.py", + title="MCP Servers", + icon="🔌", + default=False, + ) pg = st.navigation( { @@ -44,7 +50,7 @@ def main(): application_evaluation_page, native_evaluation_page, ], - "Inspect": [provider_page, resources_page], + "Inspect": [provider_page, resources_page, mcp_servers_page], }, expanded=False, ) diff --git a/llama_stack/distribution/ui/modules/api.py b/llama_stack/distribution/ui/modules/api.py index aa1584b5e..d36b3479a 100644 --- a/llama_stack/distribution/ui/modules/api.py +++ b/llama_stack/distribution/ui/modules/api.py @@ -4,12 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json import os import streamlit as st from llama_stack_client import LlamaStackClient +# Constants +MCP_CONFIG_DIR = os.path.expanduser("~/.llama_stack/mcp_servers") +MCP_CONFIG_FILE = os.path.join(MCP_CONFIG_DIR, "config.json") + class LlamaStackApi: def __init__(self): # Initialize provider data from environment variables @@ -24,22 +29,83 @@ class LlamaStackApi: # Check if we have any API keys stored in session state if st.session_state.get("tavily_search_api_key"): self.provider_data["tavily_search_api_key"] = st.session_state.get("tavily_search_api_key") + + # Load MCP server configurations + self.mcp_servers = self.load_mcp_config() # Initialize the client - self.client = LlamaStackClient( - base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"), - provider_data=self.provider_data, - ) + try: + # Try to initialize with MCP servers support + self.client = LlamaStackClient( + base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"), + provider_data=self.provider_data, + mcp_servers=self.mcp_servers.get("servers", {}), + ) + except TypeError: + # Fall back to initialization without MCP servers if not supported + self.client = LlamaStackClient( + base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"), + provider_data=self.provider_data, + ) + st.warning("MCP servers support not available in the current LlamaStackClient version.") + def load_mcp_config(self): + """Load MCP server configurations from file.""" + if os.path.exists(MCP_CONFIG_FILE): + try: + with open(MCP_CONFIG_FILE, "r") as f: + return json.load(f) + except json.JSONDecodeError: + st.error("Error loading MCP server configuration file.") + + # Default empty configuration + return { + "servers": {}, + "inputs": [] + } + def update_provider_data(self, key, value): """Update a specific provider data key and reinitialize the client""" self.provider_data[key] = value + self.update_provider_data_dict(self.provider_data) + + def update_provider_data_dict(self, provider_data): + """Update the provider data dictionary and reinitialize the client""" + self.provider_data = provider_data # Reinitialize the client with updated provider data - self.client = LlamaStackClient( - base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"), - provider_data=self.provider_data, - ) + try: + # Try to initialize with MCP servers support + self.client = LlamaStackClient( + base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"), + provider_data=self.provider_data, + mcp_servers=self.mcp_servers.get("servers", {}), + ) + except TypeError: + # Fall back to initialization without MCP servers if not supported + self.client = LlamaStackClient( + base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"), + provider_data=self.provider_data, + ) + + def update_mcp_servers(self): + """Reload MCP server configurations and reinitialize the client""" + self.mcp_servers = self.load_mcp_config() + + # Reinitialize the client with updated MCP servers + try: + # Try to initialize with MCP servers support + self.client = LlamaStackClient( + base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"), + provider_data=self.provider_data, + mcp_servers=self.mcp_servers.get("servers", {}), + ) + except TypeError: + # Fall back to initialization without MCP servers if not supported + self.client = LlamaStackClient( + base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"), + provider_data=self.provider_data, + ) def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None): """Run scoring on a single row""" diff --git a/llama_stack/distribution/ui/page/distribution/mcp_servers.py b/llama_stack/distribution/ui/page/distribution/mcp_servers.py new file mode 100644 index 000000000..9ae955dc8 --- /dev/null +++ b/llama_stack/distribution/ui/page/distribution/mcp_servers.py @@ -0,0 +1,394 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import os +import streamlit as st +from typing import Dict, List, Optional, Any + +from llama_stack.distribution.ui.modules.api import llama_stack_api + +# Constants +CONFIG_DIR = os.path.expanduser("~/.llama_stack/mcp_servers") +CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json") + +# Ensure config directory exists +os.makedirs(CONFIG_DIR, exist_ok=True) + +def load_config() -> Dict[str, Any]: + """Load MCP server configurations from file.""" + if os.path.exists(CONFIG_FILE): + try: + with open(CONFIG_FILE, "r") as f: + return json.load(f) + except json.JSONDecodeError: + st.error("Error loading MCP server configuration file. Using default configuration.") + + # Default empty configuration + return { + "servers": {}, + "inputs": [] + } + +def register_mcp_servers(config: Dict[str, Any]) -> None: + """Register MCP servers as toolgroups with the LlamaStackClient.""" + servers = config.get("servers", {}) + inputs = config.get("inputs", []) + + # Process inputs to resolve values + input_values = {} + for input_config in inputs: + input_id = input_config.get("id") + if input_id: + # Get value from session state if available + input_values[input_id] = st.session_state.get(f"input_value_{input_id}", "") + + # Update provider data with MCP headers + mcp_headers = {} + + # Register each server as a toolgroup + for server_name, server_config in servers.items(): + url = server_config.get("url", "") + if not url: + continue + + try: + # Register the MCP server as a toolgroup + toolgroup_id = f"mcp::{server_name}" + + # Register the toolgroup + llama_stack_api.client.toolgroups.register( + toolgroup_id=toolgroup_id, + provider_id="model-context-protocol", + mcp_endpoint=url, + ) + + # Process headers for this server + headers = server_config.get("headers", {}) + processed_headers = {} + + for header_key, header_value in headers.items(): + # Process input references in header values + if isinstance(header_value, str) and "${input:" in header_value: + # Extract input ID from ${input:input_id} format + input_id = header_value.split("${input:")[1].split("}")[0] + if input_id in input_values: + processed_headers[header_key] = input_values[input_id] + else: + processed_headers[header_key] = header_value + + # Add headers to mcp_headers if there are any + if processed_headers: + mcp_headers[url] = processed_headers + + except Exception as e: + st.error(f"Failed to register MCP server '{server_name}': {str(e)}") + + # Update provider data with MCP headers if there are any + if mcp_headers: + provider_data = llama_stack_api.provider_data.copy() + provider_data["mcp_headers"] = mcp_headers + llama_stack_api.update_provider_data_dict(provider_data) + +def save_config(config: Dict[str, Any]) -> None: + """Save MCP server configurations to file.""" + with open(CONFIG_FILE, "w") as f: + json.dump(config, f, indent=2) + + # Register MCP servers as toolgroups + register_mcp_servers(config) + + st.success("MCP server configuration saved successfully!") + +def render_server_config(server_name: str, server_config: Dict[str, Any], config: Dict[str, Any]) -> Dict[str, Any]: + """Render and edit configuration for a specific MCP server.""" + st.subheader(f"Server: {server_name}") + + # Server type + server_type = st.selectbox( + "Server Type", + ["http", "websocket"], + index=0 if server_config.get("type", "http") == "http" else 1, + key=f"type_{server_name}" + ) + + # Server URL + url = st.text_input( + "Server URL", + value=server_config.get("url", ""), + key=f"url_{server_name}" + ) + + # Headers + st.write("Headers:") + headers = server_config.get("headers", {}) + headers_container = st.container() + + with headers_container: + # Display existing headers + for header_key, header_value in list(headers.items()): + col1, col2, col3 = st.columns([3, 6, 1]) + with col1: + st.text(header_key) + with col2: + # Check if this is a reference to an input + if isinstance(header_value, str) and "${input:" in header_value: + st.text(header_value) + else: + st.text("********" if "token" in header_key.lower() or "auth" in header_key.lower() else header_value) + with col3: + if st.button("🗑️", key=f"delete_header_{server_name}_{header_key}"): + del headers[header_key] + + # Add new header + st.write("Add Header:") + header_col1, header_col2 = st.columns([1, 1]) + new_header_key = header_col1.text_input("Key", key=f"new_header_key_{server_name}") + new_header_value = header_col2.text_input("Value", key=f"new_header_value_{server_name}") + + if st.button("Add Header", key=f"add_header_{server_name}"): + if new_header_key and new_header_value: + headers[new_header_key] = new_header_value + st.experimental_rerun() + + # Construct updated server config + updated_server_config = { + "type": server_type, + "url": url, + "headers": headers + } + + return updated_server_config + +def render_input_config(input_config: Dict[str, Any], index: int) -> Dict[str, Any]: + """Render and edit configuration for an input field.""" + st.subheader(f"Input: {input_config.get('id', '')}") + + col1, col2 = st.columns([1, 1]) + + input_type = col1.selectbox( + "Type", + ["promptString"], + index=0, + key=f"input_type_{index}" + ) + + input_id = col2.text_input( + "ID", + value=input_config.get("id", ""), + key=f"input_id_{index}" + ) + + description = st.text_input( + "Description", + value=input_config.get("description", ""), + key=f"input_desc_{index}" + ) + + is_password = st.checkbox( + "Password Field", + value=input_config.get("password", False), + key=f"input_password_{index}" + ) + + # Construct updated input config + updated_input_config = { + "type": input_type, + "id": input_id, + "description": description, + "password": is_password + } + + return updated_input_config + +def main(): + st.title("MCP Servers Configuration") + + st.markdown(""" + Configure Model Control Protocol (MCP) servers to integrate with external AI services. + MCP is a protocol for controlling AI models through a standardized API. + + MCP servers are registered as toolgroups with the ID format `mcp::{server_name}`. + """) + + # Load existing configuration + config = load_config() + + # Tabs for different sections + tab1, tab2, tab3, tab4 = st.tabs(["Servers", "Inputs", "Input Values", "JSON Config"]) + + # Servers Tab + with tab1: + st.header("Configured Servers") + + # List existing servers + servers = config.get("servers", {}) + if not servers: + st.info("No MCP servers configured yet. Add a new server below.") + + # Server selection or creation + server_options = list(servers.keys()) + ["+ Add New Server"] + selected_server = st.selectbox("Select Server", server_options) + + if selected_server == "+ Add New Server": + new_server_name = st.text_input("New Server Name") + if new_server_name and st.button("Create Server"): + if new_server_name in servers: + st.error(f"Server '{new_server_name}' already exists.") + else: + servers[new_server_name] = { + "type": "http", + "url": "", + "headers": {} + } + st.experimental_rerun() + elif selected_server in servers: + # Edit existing server + updated_config = render_server_config(selected_server, servers[selected_server], config) + + col1, col2 = st.columns([1, 1]) + if col1.button("Update Server", key=f"update_{selected_server}"): + servers[selected_server] = updated_config + save_config(config) + + if col2.button("Delete Server", key=f"delete_{selected_server}"): + del servers[selected_server] + save_config(config) + st.experimental_rerun() + + # Inputs Tab + with tab2: + st.header("Input Configurations") + + inputs = config.get("inputs", []) + if not inputs: + st.info("No input configurations defined yet. Add a new input below.") + + # Input selection or creation + input_options = [f"{i.get('id', f'Input {idx}')} ({i.get('type', 'promptString')})" for idx, i in enumerate(inputs)] + input_options.append("+ Add New Input") + + selected_input_option = st.selectbox("Select Input", input_options) + + if selected_input_option == "+ Add New Input": + if st.button("Create Input"): + inputs.append({ + "type": "promptString", + "id": f"input_{len(inputs)}", + "description": "", + "password": False + }) + save_config(config) + st.experimental_rerun() + else: + # Edit existing input + selected_idx = input_options.index(selected_input_option) + if selected_idx < len(inputs): + updated_input = render_input_config(inputs[selected_idx], selected_idx) + + col1, col2 = st.columns([1, 1]) + if col1.button("Update Input", key=f"update_input_{selected_idx}"): + inputs[selected_idx] = updated_input + save_config(config) + + if col2.button("Delete Input", key=f"delete_input_{selected_idx}"): + inputs.pop(selected_idx) + save_config(config) + st.experimental_rerun() + + # Input Values Tab + with tab3: + st.header("Input Values") + + inputs = config.get("inputs", []) + if not inputs: + st.info("No input configurations defined yet. Add inputs in the Inputs tab.") + else: + st.write("Enter values for the configured inputs:") + + for input_config in inputs: + input_id = input_config.get("id", "") + description = input_config.get("description", "") + is_password = input_config.get("password", False) + + # Get value from session state if available + current_value = st.session_state.get(f"input_value_{input_id}", "") + + # Input field for value + value = st.text_input( + f"{description} ({input_id})", + value=current_value, + type="password" if is_password else "default", + key=f"input_value_{input_id}" + ) + + if st.button("Save Input Values"): + # Values are automatically saved to session state + # Register MCP servers to apply the new values + register_mcp_servers(config) + st.success("Input values saved successfully!") + + # JSON Config Tab + with tab4: + st.header("Raw JSON Configuration") + + # Display and edit raw JSON + json_str = json.dumps(config, indent=2) + edited_json = st.text_area("Edit JSON Configuration", json_str, height=400) + + if st.button("Update from JSON"): + try: + updated_config = json.loads(edited_json) + save_config(updated_config) + st.experimental_rerun() + except json.JSONDecodeError as e: + st.error(f"Invalid JSON: {str(e)}") + + # Example configuration and usage + with st.expander("Example Configuration and Usage"): + st.markdown(""" + ### Example Configuration + + ```json + { + "servers": { + "github": { + "type": "http", + "url": "https://api.githubcopilot.com/mcp/", + "headers": { + "Authorization": "Bearer ${input:github_mcp_pat}" + } + } + }, + "inputs": [ + { + "type": "promptString", + "id": "github_mcp_pat", + "description": "GitHub Personal Access Token", + "password": true + } + ] + } + ``` + + ### Usage in Code + + Once registered, you can use the MCP server in your code: + + ```python + from llama_stack_client import Agent + + agent = Agent( + model="llama-3-70b-instruct", + tools=["mcp::github"], # Use the registered MCP server + ) + + agent.create_turn("Use GitHub Copilot to help me write a function") + ``` + """) + +if __name__ == "__main__": + main() diff --git a/llama_stack/distribution/ui/page/distribution/providers.py b/llama_stack/distribution/ui/page/distribution/providers.py index 97d5ed4fd..57fdfee40 100644 --- a/llama_stack/distribution/ui/page/distribution/providers.py +++ b/llama_stack/distribution/ui/page/distribution/providers.py @@ -43,7 +43,9 @@ def providers(): llama_stack_api.update_provider_data("tavily_search_api_key", tavily_search_api_key) else: # Fallback implementation if method doesn't exist - llama_stack_api.provider_data = llama_stack_api.provider_data or {} + # Initialize provider_data if it doesn't exist + if not hasattr(llama_stack_api, "provider_data"): + llama_stack_api.provider_data = {} llama_stack_api.provider_data["tavily_search_api_key"] = tavily_search_api_key # Reinitialize the client with updated provider data llama_stack_api.client = LlamaStackClient(