Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 14:08:00 +00:00)
Commit dcc47c2008 (parent 4e19f15bca): 6 changed files with 351 additions and 8927 deletions
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
vLLM server Kubernetes deployment template:

@@ -39,7 +39,7 @@ spec:
         image: vllm/vllm-openai:latest
         command: ["/bin/sh", "-c"]
         args:
-          - "vllm serve ${INFERENCE_MODEL} --enforce-eager --max-model-len 8192 --gpu-memory-utilization 0.7 --enable-auto-tool-choice --tool-call-parser llama3_json --max-num-seqs 4 --port 8001"
+          - "vllm serve ${INFERENCE_MODEL} --enforce-eager --max-model-len 100000 --gpu-memory-utilization 0.9 --enable-auto-tool-choice --tool-call-parser llama3_json --max-num-seqs 2 --port 8001"
         env:
         - name: INFERENCE_MODEL
           value: "${INFERENCE_MODEL}"
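The updated flags trade concurrency for context: maximum model length goes from 8192 to 100000 tokens, GPU memory utilization from 0.7 to 0.9, and concurrent sequences from 4 to 2. A minimal sketch for checking that the reconfigured server came up, assuming port 8001 of the pod has been port-forwarded to localhost (the local URL is an assumption, not part of the template):

```python
# Smoke test against vLLM's OpenAI-compatible API.
# Assumes something like `kubectl port-forward <vllm-pod> 8001:8001` is running.
import requests

resp = requests.get("http://localhost:8001/v1/models", timeout=10)
resp.raise_for_status()
for model in resp.json().get("data", []):
    # Each entry describes one served model; "id" should match INFERENCE_MODEL.
    print(model["id"], "max_model_len:", model.get("max_model_len"))
```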
MCP server configuration page (Streamlit distribution UI):

@@ -7,202 +7,11 @@
 import json
 import os
 import streamlit as st
-from typing import Dict, List, Optional, Any
 
 from llama_stack.distribution.ui.modules.api import llama_stack_api
+from llama_stack_client import LlamaStackClient
+from llama_stack_client.types.toolgroup_register_params import McpEndpoint
 
-# Constants
-CONFIG_DIR = os.path.expanduser("~/.llama_stack/mcp_servers")
-CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
-
-# Ensure config directory exists
-os.makedirs(CONFIG_DIR, exist_ok=True)
-
-def load_config() -> Dict[str, Any]:
-    """Load MCP server configurations from file."""
-    if os.path.exists(CONFIG_FILE):
-        try:
-            with open(CONFIG_FILE, "r") as f:
-                return json.load(f)
-        except json.JSONDecodeError:
-            st.error("Error loading MCP server configuration file. Using default configuration.")
-
-    # Default empty configuration
-    return {
-        "servers": {},
-        "inputs": []
-    }
-
-def register_mcp_servers(config: Dict[str, Any]) -> None:
-    """Register MCP servers as toolgroups with the LlamaStackClient."""
-    servers = config.get("servers", {})
-    inputs = config.get("inputs", [])
-
-    # Process inputs to resolve values
-    input_values = {}
-    for input_config in inputs:
-        input_id = input_config.get("id")
-        if input_id:
-            # Get value from session state if available
-            input_values[input_id] = st.session_state.get(f"input_value_{input_id}", "")
-
-    # Update provider data with MCP headers
-    mcp_headers = {}
-
-    # Register each server as a toolgroup
-    for server_name, server_config in servers.items():
-        url = server_config.get("url", "")
-        if not url:
-            continue
-
-        try:
-            # Register the MCP server as a toolgroup
-            toolgroup_id = f"mcp::{server_name}"
-
-            # Register the toolgroup
-            llama_stack_api.client.toolgroups.register(
-                toolgroup_id=toolgroup_id,
-                provider_id="model-context-protocol",
-                mcp_endpoint=url,
-            )
-
-            # Process headers for this server
-            headers = server_config.get("headers", {})
-            processed_headers = {}
-
-            for header_key, header_value in headers.items():
-                # Process input references in header values
-                if isinstance(header_value, str) and "${input:" in header_value:
-                    # Extract input ID from ${input:input_id} format
-                    input_id = header_value.split("${input:")[1].split("}")[0]
-                    if input_id in input_values:
-                        processed_headers[header_key] = input_values[input_id]
-                else:
-                    processed_headers[header_key] = header_value
-
-            # Add headers to mcp_headers if there are any
-            if processed_headers:
-                mcp_headers[url] = processed_headers
-
-        except Exception as e:
-            st.error(f"Failed to register MCP server '{server_name}': {str(e)}")
-
-    # Update provider data with MCP headers if there are any
-    if mcp_headers:
-        provider_data = llama_stack_api.provider_data.copy()
-        provider_data["mcp_headers"] = mcp_headers
-        llama_stack_api.update_provider_data_dict(provider_data)
-
-def save_config(config: Dict[str, Any]) -> None:
-    """Save MCP server configurations to file."""
-    with open(CONFIG_FILE, "w") as f:
-        json.dump(config, f, indent=2)
-
-    # Register MCP servers as toolgroups
-    register_mcp_servers(config)
-
-    st.success("MCP server configuration saved successfully!")
-
-def render_server_config(server_name: str, server_config: Dict[str, Any], config: Dict[str, Any]) -> Dict[str, Any]:
-    """Render and edit configuration for a specific MCP server."""
-    st.subheader(f"Server: {server_name}")
-
-    # Server type
-    server_type = st.selectbox(
-        "Server Type",
-        ["http", "websocket"],
-        index=0 if server_config.get("type", "http") == "http" else 1,
-        key=f"type_{server_name}"
-    )
-
-    # Server URL
-    url = st.text_input(
-        "Server URL",
-        value=server_config.get("url", ""),
-        key=f"url_{server_name}"
-    )
-
-    # Headers
-    st.write("Headers:")
-    headers = server_config.get("headers", {})
-    headers_container = st.container()
-
-    with headers_container:
-        # Display existing headers
-        for header_key, header_value in list(headers.items()):
-            col1, col2, col3 = st.columns([3, 6, 1])
-            with col1:
-                st.text(header_key)
-            with col2:
-                # Check if this is a reference to an input
-                if isinstance(header_value, str) and "${input:" in header_value:
-                    st.text(header_value)
-                else:
-                    st.text("********" if "token" in header_key.lower() or "auth" in header_key.lower() else header_value)
-            with col3:
-                if st.button("🗑️", key=f"delete_header_{server_name}_{header_key}"):
-                    del headers[header_key]
-
-        # Add new header
-        st.write("Add Header:")
-        header_col1, header_col2 = st.columns([1, 1])
-        new_header_key = header_col1.text_input("Key", key=f"new_header_key_{server_name}")
-        new_header_value = header_col2.text_input("Value", key=f"new_header_value_{server_name}")
-
-        if st.button("Add Header", key=f"add_header_{server_name}"):
-            if new_header_key and new_header_value:
-                headers[new_header_key] = new_header_value
-                st.experimental_rerun()
-
-    # Construct updated server config
-    updated_server_config = {
-        "type": server_type,
-        "url": url,
-        "headers": headers
-    }
-
-    return updated_server_config
-
-def render_input_config(input_config: Dict[str, Any], index: int) -> Dict[str, Any]:
-    """Render and edit configuration for an input field."""
-    st.subheader(f"Input: {input_config.get('id', '')}")
-
-    col1, col2 = st.columns([1, 1])
-
-    input_type = col1.selectbox(
-        "Type",
-        ["promptString"],
-        index=0,
-        key=f"input_type_{index}"
-    )
-
-    input_id = col2.text_input(
-        "ID",
-        value=input_config.get("id", ""),
-        key=f"input_id_{index}"
-    )
-
-    description = st.text_input(
-        "Description",
-        value=input_config.get("description", ""),
-        key=f"input_desc_{index}"
-    )
-
-    is_password = st.checkbox(
-        "Password Field",
-        value=input_config.get("password", False),
-        key=f"input_password_{index}"
-    )
-
-    # Construct updated input config
-    updated_input_config = {
-        "type": input_type,
-        "id": input_id,
-        "description": description,
-        "password": is_password
-    }
-
-    return updated_input_config
 
 
 def main():
     st.title("MCP Servers Configuration")
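The deleted `register_mcp_servers` helper resolved `${input:<id>}` placeholders in header values by string splitting before registering each server. If that substitution is still useful elsewhere, here is a standalone sketch of the same idea; the name `resolve_input_refs` is illustrative and not part of the codebase, and unlike the removed code it keeps unresolved references instead of dropping them:

```python
import re
from typing import Any, Dict

_INPUT_REF = re.compile(r"\$\{input:([^}]+)\}")

def resolve_input_refs(headers: Dict[str, str], input_values: Dict[str, Any]) -> Dict[str, str]:
    """Replace ${input:<id>} references in header values with concrete values."""
    resolved = {}
    for key, value in headers.items():
        if isinstance(value, str):
            # Substitute every reference whose id is known; leave unknown ids untouched.
            value = _INPUT_REF.sub(
                lambda m: str(input_values.get(m.group(1), m.group(0))), value
            )
        resolved[key] = value
    return resolved

# Example:
# resolve_input_refs({"Authorization": "Bearer ${input:github_mcp_pat}"}, {"github_mcp_pat": "ghp_abc"})
# -> {"Authorization": "Bearer ghp_abc"}
```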
@@ -214,181 +23,318 @@ def main():
     MCP servers are registered as toolgroups with the ID format `mcp::{server_name}`.
     """)
 
-    # Load existing configuration
-    config = load_config()
-
-    # Tabs for different sections
-    tab1, tab2, tab3, tab4 = st.tabs(["Servers", "Inputs", "Input Values", "JSON Config"])
-
-    # Servers Tab
-    with tab1:
-        st.header("Configured Servers")
-
-        # List existing servers
-        servers = config.get("servers", {})
-        if not servers:
-            st.info("No MCP servers configured yet. Add a new server below.")
-
-        # Server selection or creation
-        server_options = list(servers.keys()) + ["+ Add New Server"]
-        selected_server = st.selectbox("Select Server", server_options)
-
-        if selected_server == "+ Add New Server":
-            new_server_name = st.text_input("New Server Name")
-            if new_server_name and st.button("Create Server"):
-                if new_server_name in servers:
-                    st.error(f"Server '{new_server_name}' already exists.")
-                else:
-                    servers[new_server_name] = {
-                        "type": "http",
-                        "url": "",
-                        "headers": {}
-                    }
-                    st.experimental_rerun()
-        elif selected_server in servers:
-            # Edit existing server
-            updated_config = render_server_config(selected_server, servers[selected_server], config)
-
-            col1, col2 = st.columns([1, 1])
-            if col1.button("Update Server", key=f"update_{selected_server}"):
-                servers[selected_server] = updated_config
-                save_config(config)
-
-            if col2.button("Delete Server", key=f"delete_{selected_server}"):
-                del servers[selected_server]
-                save_config(config)
-                st.experimental_rerun()
-
-    # Inputs Tab
-    with tab2:
-        st.header("Input Configurations")
-
-        inputs = config.get("inputs", [])
-        if not inputs:
-            st.info("No input configurations defined yet. Add a new input below.")
-
-        # Input selection or creation
-        input_options = [f"{i.get('id', f'Input {idx}')} ({i.get('type', 'promptString')})" for idx, i in enumerate(inputs)]
-        input_options.append("+ Add New Input")
-
-        selected_input_option = st.selectbox("Select Input", input_options)
-
-        if selected_input_option == "+ Add New Input":
-            if st.button("Create Input"):
-                inputs.append({
-                    "type": "promptString",
-                    "id": f"input_{len(inputs)}",
-                    "description": "",
-                    "password": False
-                })
-                save_config(config)
-                st.experimental_rerun()
-        else:
-            # Edit existing input
-            selected_idx = input_options.index(selected_input_option)
-            if selected_idx < len(inputs):
-                updated_input = render_input_config(inputs[selected_idx], selected_idx)
-
-                col1, col2 = st.columns([1, 1])
-                if col1.button("Update Input", key=f"update_input_{selected_idx}"):
-                    inputs[selected_idx] = updated_input
-                    save_config(config)
-
-                if col2.button("Delete Input", key=f"delete_input_{selected_idx}"):
-                    inputs.pop(selected_idx)
-                    save_config(config)
-                    st.experimental_rerun()
-
-    # Input Values Tab
-    with tab3:
-        st.header("Input Values")
-
-        inputs = config.get("inputs", [])
-        if not inputs:
-            st.info("No input configurations defined yet. Add inputs in the Inputs tab.")
-        else:
-            st.write("Enter values for the configured inputs:")
-
-            for input_config in inputs:
-                input_id = input_config.get("id", "")
-                description = input_config.get("description", "")
-                is_password = input_config.get("password", False)
-
-                # Get value from session state if available
-                current_value = st.session_state.get(f"input_value_{input_id}", "")
-
-                # Input field for value
-                value = st.text_input(
-                    f"{description} ({input_id})",
-                    value=current_value,
-                    type="password" if is_password else "default",
-                    key=f"input_value_{input_id}"
-                )
-
-            if st.button("Save Input Values"):
-                # Values are automatically saved to session state
-                # Register MCP servers to apply the new values
-                register_mcp_servers(config)
-                st.success("Input values saved successfully!")
-
-    # JSON Config Tab
-    with tab4:
-        st.header("Raw JSON Configuration")
-
-        # Display and edit raw JSON
-        json_str = json.dumps(config, indent=2)
-        edited_json = st.text_area("Edit JSON Configuration", json_str, height=400)
-
-        if st.button("Update from JSON"):
-            try:
-                updated_config = json.loads(edited_json)
-                save_config(updated_config)
-                st.experimental_rerun()
-            except json.JSONDecodeError as e:
-                st.error(f"Invalid JSON: {str(e)}")
-
-    # Example configuration and usage
-    with st.expander("Example Configuration and Usage"):
-        st.markdown("""
-        ### Example Configuration
-
-        ```json
-        {
-            "servers": {
-                "github": {
-                    "type": "http",
-                    "url": "https://api.githubcopilot.com/mcp/",
-                    "headers": {
-                        "Authorization": "Bearer ${input:github_mcp_pat}"
-                    }
-                }
-            },
-            "inputs": [
-                {
-                    "type": "promptString",
-                    "id": "github_mcp_pat",
-                    "description": "GitHub Personal Access Token",
-                    "password": true
-                }
-            ]
-        }
-        ```
-
-        ### Usage in Code
-
-        Once registered, you can use the MCP server in your code:
-
-        ```python
-        from llama_stack_client import Agent
-
-        agent = Agent(
-            model="llama-3-70b-instruct",
-            tools=["mcp::github"],  # Use the registered MCP server
-        )
-
-        agent.create_turn("Use GitHub Copilot to help me write a function")
-        ```
-        """)
+    # MCP Server Configuration
+    st.header("MCP Server Configuration")
+
+    # Create a form for MCP configuration
+    with st.form("mcp_server_form"):
+        # Server name
+        server_name = st.text_input(
+            "Server Name",
+            value="github",
+            help="A unique name for this MCP server. Will be used in the toolgroup ID as mcp::{name}."
+        )
+
+        # MCP URL
+        mcp_url = st.text_input(
+            "MCP URL",
+            value="https://api.githubcopilot.com/mcp/",
+            help="The URL of the MCP server."
+        )
+
+        # Get the current value from session state
+        api_token = st.session_state.get(f"mcp_token_{server_name}", "")
+
+        # Input field for API Bearer Token
+        mcp_token = st.text_input(
+            "API Bearer Token",
+            value=api_token,
+            type="password",
+            help="Enter your API Bearer Token. For GitHub Copilot, this should be a GitHub Personal Access Token with Copilot scope."
+        )
+
+        # Submit button
+        submit_button = st.form_submit_button("Save Configuration")
+
+    if submit_button:
+        if not server_name:
+            st.error("Server name is required.")
+        elif not mcp_url:
+            st.error("MCP URL is required.")
+        else:
+            # Store the token in session state
+            st.session_state[f"mcp_token_{server_name}"] = mcp_token
+
+            try:
+                # Register the MCP server as a toolgroup
+                toolgroup_id = f"mcp::{server_name}"
+                try:
+                    llama_stack_api.client.toolgroups.register(
+                        toolgroup_id=toolgroup_id,
+                        provider_id="model-context-protocol",
+                        mcp_endpoint=McpEndpoint(uri=mcp_url),
+                        timeout=10.0,  # Set a reasonable timeout
+                    )
+                except Exception as e:
+                    if "timeout" in str(e).lower():
+                        st.warning(f"Registration timed out, but configuration will still be saved. Error: {str(e)}")
+                    else:
+                        raise
+
+                # Update provider data with MCP headers
+                # Check if provider_data attribute exists
+                if not hasattr(llama_stack_api, "provider_data"):
+                    llama_stack_api.provider_data = {}
+
+                provider_data = llama_stack_api.provider_data.copy()
+                if "mcp_headers" not in provider_data:
+                    provider_data["mcp_headers"] = {}
+
+                # Add MCP headers
+                if mcp_token:
+                    # Clean the token (remove 'Bearer ' prefix if present)
+                    clean_token = mcp_token
+                    if clean_token.lower().startswith("bearer "):
+                        clean_token = clean_token[7:]
+
+                    # Set the headers for this MCP server
+                    # The key needs to be the exact URL used in the MCP endpoint
+                    provider_data["mcp_headers"][mcp_url] = {
+                        "Authorization": f"Bearer {clean_token}"
+                    }
+
+                    # Debug information
+                    st.info(f"Set authentication headers for {mcp_url}: Bearer {clean_token[:4]}...")
+
+                    # Also set the token directly in the provider_data using different formats
+                    # This increases the chance that one of them will work
+                    provider_data[f"{server_name}_api_key"] = clean_token
+                    provider_data[f"{server_name}_token"] = clean_token
+                    provider_data[f"{server_name}_mcp_token"] = clean_token
+
+                    # Display the current provider_data for debugging
+                    st.write("Current provider_data:")
+
+                    # Mask the token for display
+                    masked_token = ""
+                    if clean_token:
+                        if len(clean_token) > 8:
+                            # Show first 2 and last 2 characters, mask the middle with asterisks
+                            masked_token = f"Bearer {clean_token[:2]}{'*' * 6}{clean_token[-2:]}"
+                        else:
+                            # For short tokens, just show first 2 chars and asterisks
+                            masked_token = f"Bearer {clean_token[:2]}{'*' * 4}"
+
+                    # Create a sanitized version of mcp_headers for display
+                    sanitized_headers = {}
+                    for url, headers in provider_data.get("mcp_headers", {}).items():
+                        sanitized_headers[url] = {}
+                        for header_key, header_value in headers.items():
+                            if header_key.lower() == "authorization" and isinstance(header_value, str):
+                                if header_value.lower().startswith("bearer "):
+                                    token = header_value[7:]
+                                    if len(token) > 8:
+                                        sanitized_headers[url][header_key] = f"Bearer {token[:2]}{'*' * 6}{token[-2:]}"
+                                    else:
+                                        sanitized_headers[url][header_key] = f"Bearer {token[:2]}{'*' * 4}"
+                                else:
+                                    sanitized_headers[url][header_key] = f"{header_value[:2]}{'*' * 6}"
+                            else:
+                                sanitized_headers[url][header_key] = header_value
+
+                    st.json({
+                        "mcp_headers": sanitized_headers,
+                        f"{server_name}_api_key": masked_token
+                    })
+
+                    # Add a note about authentication
+                    st.warning("""
+                    **Important Authentication Notes:**
+
+                    1. If you encounter authentication errors when using this MCP server,
+                    try restarting the UI application to ensure the authentication headers are properly applied.
+
+                    2. For GitHub Copilot, make sure you're using a valid GitHub Personal Access Token with the 'copilot' scope.
+
+                    3. When using the MCP server in code, you may need to explicitly pass the authentication headers:
+                    ```python
+                    import json
+                    from llama_stack_client import Agent
+
+                    agent = Agent(
+                        model="llama-3-70b-instruct",
+                        tools=["mcp::github"],
+                        extra_headers={
+                            "X-LlamaStack-Provider-Data": json.dumps({
+                                "mcp_headers": {
+                                    "https://api.githubcopilot.com/mcp/": {
+                                        "Authorization": "Bearer YOUR_TOKEN_HERE"
+                                    }
+                                }
+                            })
+                        }
+                    )
+                    ```
+                    """)
+
+                # Update the client with the new provider data
+                if hasattr(llama_stack_api, "update_provider_data_dict"):
+                    llama_stack_api.update_provider_data_dict(provider_data)
+                else:
+                    # Fallback implementation if method doesn't exist
+                    llama_stack_api.provider_data = provider_data
+                    # Reinitialize the client with updated provider data
+                    llama_stack_api.client = LlamaStackClient(
+                        base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
+                        provider_data=provider_data,
+                    )
+
+                st.success(f"MCP server '{server_name}' configured successfully!")
+            except Exception as e:
+                st.error(f"Failed to configure MCP server '{server_name}': {str(e)}")
+
+    # Usage example
+    with st.expander("Usage Example"):
+        st.markdown("""
+        ### Using MCP Servers in Code
+
+        Once configured, you can use the MCP server in your code:
+
+        ```python
+        from llama_stack_client import Agent
+
+        agent = Agent(
+            model="llama-3-70b-instruct",
+            tools=["mcp::your_server_name"],  # Use the registered MCP server
+        )
+
+        agent.create_turn("Use the MCP server to help me with a task")
+        ```
+
+        ### With Explicit Authentication Headers
+
+        If you encounter authentication issues, you can explicitly pass the authentication headers:
+
+        ```python
+        import json
+        from llama_stack_client import Agent
+
+        agent = Agent(
+            model="llama-3-70b-instruct",
+            tools=["mcp::github"],
+            extra_headers={
+                "X-LlamaStack-Provider-Data": json.dumps({
+                    "mcp_headers": {
+                        "https://api.githubcopilot.com/mcp/": {
+                            "Authorization": "Bearer YOUR_TOKEN_HERE"
+                        }
+                    }
+                })
+            }
+        )
+        ```
+        """)
+
+    # Display registered toolgroups
+    st.header("Registered MCP Toolgroups")
+    try:
+        toolgroups = llama_stack_api.client.toolgroups.list(timeout=5.0)  # Set a reasonable timeout
+
+        # Debug the structure of the toolgroups
+        if toolgroups:
+            # Check the first toolgroup to determine its structure
+            first_toolgroup = toolgroups[0] if toolgroups else None
+
+            if first_toolgroup:
+                # Get the attribute names
+                attr_names = dir(first_toolgroup)
+
+                # The ToolGroup class has an 'identifier' attribute as per the API definition
+                id_attr = 'identifier'
+
+                # Filter MCP toolgroups based on the identifier attribute
+                mcp_toolgroups = []
+                for tg in toolgroups:
+                    if hasattr(tg, 'identifier') and isinstance(tg.identifier, str) and tg.identifier.startswith("mcp::"):
+                        mcp_toolgroups.append(tg)
+
+                if mcp_toolgroups:
+                    # Display MCP servers with unregister buttons
+                    st.write("Click the button next to a server to unregister it:")
+
+                    for tg in mcp_toolgroups:
+                        col1, col2, col3, col4 = st.columns([3, 2, 3, 1])
+
+                        with col1:
+                            st.write(f"**{tg.identifier}**")
+
+                        with col2:
+                            st.write(tg.provider_id)
+
+                        with col3:
+                            if hasattr(tg, 'mcp_endpoint') and tg.mcp_endpoint:
+                                st.write(tg.mcp_endpoint.uri)
+                            else:
+                                st.write("N/A")
+
+                        with col4:
+                            # Extract server name from identifier (remove "mcp::" prefix)
+                            server_name = tg.identifier.replace("mcp::", "")
+                            if st.button("Unregister", key=f"unregister_{server_name}"):
+                                try:
+                                    # Call the unregister API
+                                    llama_stack_api.client.toolgroups.unregister(
+                                        toolgroup_id=tg.identifier,
+                                        timeout=5.0
+                                    )
+
+                                    # Also clean up provider data
+                                    if hasattr(llama_stack_api, "provider_data"):
+                                        provider_data = llama_stack_api.provider_data.copy()
+
+                                        # Remove MCP headers for this server if they exist
+                                        if "mcp_headers" in provider_data and hasattr(tg, 'mcp_endpoint') and tg.mcp_endpoint:
+                                            if tg.mcp_endpoint.uri in provider_data["mcp_headers"]:
+                                                del provider_data["mcp_headers"][tg.mcp_endpoint.uri]
+
+                                        # Remove server-specific tokens
+                                        keys_to_remove = [
+                                            f"{server_name}_api_key",
+                                            f"{server_name}_token",
+                                            f"{server_name}_mcp_token"
+                                        ]
+                                        for key in keys_to_remove:
+                                            if key in provider_data:
+                                                del provider_data[key]
+
+                                        # Update the client with the modified provider data
+                                        if hasattr(llama_stack_api, "update_provider_data_dict"):
+                                            llama_stack_api.update_provider_data_dict(provider_data)
+                                        else:
+                                            # Fallback implementation
+                                            llama_stack_api.provider_data = provider_data
+                                            llama_stack_api.client = LlamaStackClient(
+                                                base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
+                                                provider_data=provider_data,
+                                            )
+
+                                    st.success(f"Successfully unregistered {tg.identifier}")
+                                    st.rerun()  # Refresh the page to update the list
+                                except Exception as e:
+                                    st.error(f"Failed to unregister {tg.identifier}: {str(e)}")
+                else:
+                    st.info("No MCP toolgroups registered yet.")
+            else:
+                st.info("No toolgroups found.")
+        else:
+            st.info("No toolgroups returned from the API.")
+    except Exception as e:
+        if "timeout" in str(e).lower():
+            st.warning("Listing toolgroups timed out. The server might be busy or unreachable.")
+        else:
+            st.error(f"Failed to list toolgroups: {str(e)}")
 
 
 if __name__ == "__main__":
     main()
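The new page carries MCP authentication to the Llama Stack server as per-request provider data: an `mcp_headers` mapping keyed by the exact URL the toolgroup was registered with, serialized into the `X-LlamaStack-Provider-Data` header (as the embedded examples above show). A small sketch that builds that header value; the helper name `provider_data_header` is illustrative, not part of the client library:

```python
import json
from typing import Dict

def provider_data_header(mcp_url: str, token: str) -> Dict[str, str]:
    """Build the extra header that carries MCP auth to the Llama Stack server.

    The mcp_headers mapping must be keyed by the exact URL the toolgroup was
    registered with, mirroring what the configuration page above does.
    """
    if token.lower().startswith("bearer "):
        token = token[7:]  # keep the bare token; the scheme is added once below
    return {
        "X-LlamaStack-Provider-Data": json.dumps(
            {"mcp_headers": {mcp_url: {"Authorization": f"Bearer {token}"}}}
        )
    }

# Usage with the client-side example from the diff:
# agent = Agent(..., tools=["mcp::github"],
#               extra_headers=provider_data_header("https://api.githubcopilot.com/mcp/", "YOUR_TOKEN_HERE"))
```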
Tool chat page (Streamlit distribution UI):

@@ -80,9 +80,14 @@ def tool_chat_page():
     total_tools = 0
 
     for toolgroup_id in toolgroup_selection:
-        tools = client.tools.list(toolgroup_id=toolgroup_id)
-        grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
-        total_tools += len(tools)
+        try:
+            # Add a timeout of 5 seconds to prevent UI freezing
+            tools = client.tools.list(toolgroup_id=toolgroup_id, timeout=5.0)
+            grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
+            total_tools += len(tools)
+        except Exception as e:
+            st.warning(f"Failed to list tools for {toolgroup_id}: {str(e)}")
+            grouped_tools[toolgroup_id] = []
 
     st.markdown(f"Active Tools: 🛠 {total_tools}")
@@ -126,25 +131,56 @@ def tool_chat_page():
 
     @st.cache_resource
     def create_agent():
-        if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
-            return ReActAgent(
-                client=client,
-                model=model,
-                tools=toolgroup_selection,
-                response_format={
-                    "type": "json_schema",
-                    "json_schema": ReActOutput.model_json_schema(),
-                },
-                sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
-            )
-        else:
-            return Agent(
-                client,
-                model=model,
-                instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
-                tools=toolgroup_selection,
-                sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
-            )
+        try:
+            if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
+                return ReActAgent(
+                    client=client,
+                    model=model,
+                    tools=toolgroup_selection,
+                    response_format={
+                        "type": "json_schema",
+                        "json_schema": ReActOutput.model_json_schema(),
+                    },
+                    sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
+                )
+            else:
+                return Agent(
+                    client,
+                    model=model,
+                    instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
+                    tools=toolgroup_selection,
+                    sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
+                )
+        except Exception as e:
+            # Log the error
+            st.error(f"Failed to create agent: {str(e)}")
+
+            # Create a fallback agent without tools
+            st.warning(
+                "Creating a fallback agent without tools due to an error. "
+                "Some functionality may be limited. Try refreshing the page or selecting different tools."
+            )
+
+            # Return a basic agent without tools
+            if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
+                return ReActAgent(
+                    client=client,
+                    model=model,
+                    tools=[],  # No tools
+                    response_format={
+                        "type": "json_schema",
+                        "json_schema": ReActOutput.model_json_schema(),
+                    },
+                    sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
+                )
+            else:
+                return Agent(
+                    client,
+                    model=model,
+                    instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
+                    tools=[],  # No tools
+                    sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
+                )
 
     st.session_state.agent_type = agent_type
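The agent-creation hunk repeats the ReAct/plain constructors for both the normal path and the no-tools fallback. One way to keep the fallback behavior without the duplication is to parameterize the tool list and retry once on failure; this is a sketch of that refactoring, with hypothetical names, not the patch's implementation:

```python
import streamlit as st

def create_agent_with_fallback(build, toolgroup_selection):
    """build(tools) constructs the agent; retry once with no tools if construction fails.

    `build` is a caller-supplied closure (e.g. wrapping Agent or ReActAgent),
    so the constructor arguments are written only once.
    """
    try:
        return build(toolgroup_selection)
    except Exception as e:
        st.error(f"Failed to create agent: {e}")
        st.warning(
            "Creating a fallback agent without tools due to an error. "
            "Some functionality may be limited."
        )
        return build([])  # no tools
```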