Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 14:08:00 +00:00)

parent 4e19f15bca
commit dcc47c2008

6 changed files with 351 additions and 8927 deletions
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
@@ -39,7 +39,7 @@ spec:
          image: vllm/vllm-openai:latest
          command: ["/bin/sh", "-c"]
          args:
            - "vllm serve ${INFERENCE_MODEL} --enforce-eager --max-model-len 8192 --gpu-memory-utilization 0.7 --enable-auto-tool-choice --tool-call-parser llama3_json --max-num-seqs 4 --port 8001"
            - "vllm serve ${INFERENCE_MODEL} --enforce-eager --max-model-len 100000 --gpu-memory-utilization 0.9 --enable-auto-tool-choice --tool-call-parser llama3_json --max-num-seqs 2 --port 8001"
          env:
            - name: INFERENCE_MODEL
              value: "${INFERENCE_MODEL}"
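The hunk above swaps the vLLM launch arguments (context length, GPU memory utilization, and maximum concurrent sequences). As a quick sanity check that the server came up with the intended flags, here is a minimal sketch, not part of the diff, that queries the OpenAI-compatible `/v1/models` endpoint vLLM exposes on the configured port; it assumes port 8001 has been made reachable locally first (for example via a hypothetical `kubectl port-forward deploy/vllm 8001:8001`).

```python
# Minimal sketch (not part of the diff): confirm the vLLM OpenAI-compatible server is up
# and serving the expected model. Assumes port 8001 has been forwarded locally beforehand.
import json
import urllib.request

with urllib.request.urlopen("http://localhost:8001/v1/models", timeout=10) as resp:
    payload = json.load(resp)

# vLLM lists the served model id(s); this should match the INFERENCE_MODEL env var.
for model in payload.get("data", []):
    print(model["id"])
```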
@@ -7,202 +7,11 @@
import json
import os
import streamlit as st
from typing import Dict, List, Optional, Any

from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack_client import LlamaStackClient
from llama_stack_client.types.toolgroup_register_params import McpEndpoint

# Constants
CONFIG_DIR = os.path.expanduser("~/.llama_stack/mcp_servers")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")

# Ensure config directory exists
os.makedirs(CONFIG_DIR, exist_ok=True)

def load_config() -> Dict[str, Any]:
    """Load MCP server configurations from file."""
    if os.path.exists(CONFIG_FILE):
        try:
            with open(CONFIG_FILE, "r") as f:
                return json.load(f)
        except json.JSONDecodeError:
            st.error("Error loading MCP server configuration file. Using default configuration.")

    # Default empty configuration
    return {
        "servers": {},
        "inputs": []
    }

def register_mcp_servers(config: Dict[str, Any]) -> None:
    """Register MCP servers as toolgroups with the LlamaStackClient."""
    servers = config.get("servers", {})
    inputs = config.get("inputs", [])

    # Process inputs to resolve values
    input_values = {}
    for input_config in inputs:
        input_id = input_config.get("id")
        if input_id:
            # Get value from session state if available
            input_values[input_id] = st.session_state.get(f"input_value_{input_id}", "")

    # Update provider data with MCP headers
    mcp_headers = {}

    # Register each server as a toolgroup
    for server_name, server_config in servers.items():
        url = server_config.get("url", "")
        if not url:
            continue

        try:
            # Register the MCP server as a toolgroup
            toolgroup_id = f"mcp::{server_name}"

            # Register the toolgroup
            llama_stack_api.client.toolgroups.register(
                toolgroup_id=toolgroup_id,
                provider_id="model-context-protocol",
                mcp_endpoint=url,
            )

            # Process headers for this server
            headers = server_config.get("headers", {})
            processed_headers = {}

            for header_key, header_value in headers.items():
                # Process input references in header values
                if isinstance(header_value, str) and "${input:" in header_value:
                    # Extract input ID from ${input:input_id} format
                    input_id = header_value.split("${input:")[1].split("}")[0]
                    if input_id in input_values:
                        processed_headers[header_key] = input_values[input_id]
                else:
                    processed_headers[header_key] = header_value

            # Add headers to mcp_headers if there are any
            if processed_headers:
                mcp_headers[url] = processed_headers

        except Exception as e:
            st.error(f"Failed to register MCP server '{server_name}': {str(e)}")

    # Update provider data with MCP headers if there are any
    if mcp_headers:
        provider_data = llama_stack_api.provider_data.copy()
        provider_data["mcp_headers"] = mcp_headers
        llama_stack_api.update_provider_data_dict(provider_data)

def save_config(config: Dict[str, Any]) -> None:
    """Save MCP server configurations to file."""
    with open(CONFIG_FILE, "w") as f:
        json.dump(config, f, indent=2)

    # Register MCP servers as toolgroups
    register_mcp_servers(config)

    st.success("MCP server configuration saved successfully!")

def render_server_config(server_name: str, server_config: Dict[str, Any], config: Dict[str, Any]) -> Dict[str, Any]:
    """Render and edit configuration for a specific MCP server."""
    st.subheader(f"Server: {server_name}")

    # Server type
    server_type = st.selectbox(
        "Server Type",
        ["http", "websocket"],
        index=0 if server_config.get("type", "http") == "http" else 1,
        key=f"type_{server_name}"
    )

    # Server URL
    url = st.text_input(
        "Server URL",
        value=server_config.get("url", ""),
        key=f"url_{server_name}"
    )

    # Headers
    st.write("Headers:")
    headers = server_config.get("headers", {})
    headers_container = st.container()

    with headers_container:
        # Display existing headers
        for header_key, header_value in list(headers.items()):
            col1, col2, col3 = st.columns([3, 6, 1])
            with col1:
                st.text(header_key)
            with col2:
                # Check if this is a reference to an input
                if isinstance(header_value, str) and "${input:" in header_value:
                    st.text(header_value)
                else:
                    st.text("********" if "token" in header_key.lower() or "auth" in header_key.lower() else header_value)
            with col3:
                if st.button("🗑️", key=f"delete_header_{server_name}_{header_key}"):
                    del headers[header_key]

        # Add new header
        st.write("Add Header:")
        header_col1, header_col2 = st.columns([1, 1])
        new_header_key = header_col1.text_input("Key", key=f"new_header_key_{server_name}")
        new_header_value = header_col2.text_input("Value", key=f"new_header_value_{server_name}")

        if st.button("Add Header", key=f"add_header_{server_name}"):
            if new_header_key and new_header_value:
                headers[new_header_key] = new_header_value
                st.experimental_rerun()

    # Construct updated server config
    updated_server_config = {
        "type": server_type,
        "url": url,
        "headers": headers
    }

    return updated_server_config

def render_input_config(input_config: Dict[str, Any], index: int) -> Dict[str, Any]:
    """Render and edit configuration for an input field."""
    st.subheader(f"Input: {input_config.get('id', '')}")

    col1, col2 = st.columns([1, 1])

    input_type = col1.selectbox(
        "Type",
        ["promptString"],
        index=0,
        key=f"input_type_{index}"
    )

    input_id = col2.text_input(
        "ID",
        value=input_config.get("id", ""),
        key=f"input_id_{index}"
    )

    description = st.text_input(
        "Description",
        value=input_config.get("description", ""),
        key=f"input_desc_{index}"
    )

    is_password = st.checkbox(
        "Password Field",
        value=input_config.get("password", False),
        key=f"input_password_{index}"
    )

    # Construct updated input config
    updated_input_config = {
        "type": input_type,
        "id": input_id,
        "description": description,
        "password": is_password
    }

    return updated_input_config

def main():
    st.title("MCP Servers Configuration")
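The `register_mcp_servers` helper shown above resolves `${input:<id>}` placeholders in header values against values collected from the UI before attaching them as MCP headers. A standalone sketch of that substitution step follows; the helper name and example values are illustrative, not part of the diff.

```python
# Illustrative sketch of the ${input:<id>} substitution done in register_mcp_servers above.
# As in the diff, the whole header value is replaced by the resolved input value.
from typing import Dict


def resolve_headers(headers: Dict[str, str], input_values: Dict[str, str]) -> Dict[str, str]:
    resolved: Dict[str, str] = {}
    for key, value in headers.items():
        if isinstance(value, str) and "${input:" in value:
            # Extract the id from the ${input:<id>} reference.
            input_id = value.split("${input:")[1].split("}")[0]
            if input_id in input_values:
                resolved[key] = input_values[input_id]
        else:
            resolved[key] = value
    return resolved


# Mirrors the sample config shown later on this page (token value is a placeholder).
headers = {"Authorization": "Bearer ${input:github_mcp_pat}"}
print(resolve_headers(headers, {"github_mcp_pat": "ghp_example_token"}))
# -> {'Authorization': 'ghp_example_token'}
```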
@@ -214,181 +23,318 @@ def main():
MCP servers are registered as toolgroups with the ID format `mcp::{server_name}`.
    """)

    # Load existing configuration
    config = load_config()
    # MCP Server Configuration
    st.header("MCP Server Configuration")

    # Tabs for different sections
    tab1, tab2, tab3, tab4 = st.tabs(["Servers", "Inputs", "Input Values", "JSON Config"])
    # Create a form for MCP configuration
    with st.form("mcp_server_form"):
        # Server name
        server_name = st.text_input(
            "Server Name",
            value="github",
            help="A unique name for this MCP server. Will be used in the toolgroup ID as mcp::{name}."
        )

    # Servers Tab
    with tab1:
        st.header("Configured Servers")
        # MCP URL
        mcp_url = st.text_input(
            "MCP URL",
            value="https://api.githubcopilot.com/mcp/",
            help="The URL of the MCP server."
        )

        # List existing servers
        servers = config.get("servers", {})
        if not servers:
            st.info("No MCP servers configured yet. Add a new server below.")
        # Get the current value from session state
        api_token = st.session_state.get(f"mcp_token_{server_name}", "")

        # Server selection or creation
        server_options = list(servers.keys()) + ["+ Add New Server"]
        selected_server = st.selectbox("Select Server", server_options)
        # Input field for API Bearer Token
        mcp_token = st.text_input(
            "API Bearer Token",
            value=api_token,
            type="password",
            help="Enter your API Bearer Token. For GitHub Copilot, this should be a GitHub Personal Access Token with Copilot scope."
        )

        if selected_server == "+ Add New Server":
            new_server_name = st.text_input("New Server Name")
            if new_server_name and st.button("Create Server"):
                if new_server_name in servers:
                    st.error(f"Server '{new_server_name}' already exists.")
                else:
                    servers[new_server_name] = {
                        "type": "http",
                        "url": "",
                        "headers": {}
                    }
                    st.experimental_rerun()
        elif selected_server in servers:
            # Edit existing server
            updated_config = render_server_config(selected_server, servers[selected_server], config)

            col1, col2 = st.columns([1, 1])
            if col1.button("Update Server", key=f"update_{selected_server}"):
                servers[selected_server] = updated_config
                save_config(config)
        # Submit button
        submit_button = st.form_submit_button("Save Configuration")

            if col2.button("Delete Server", key=f"delete_{selected_server}"):
                del servers[selected_server]
                save_config(config)
                st.experimental_rerun()
    if submit_button:
        if not server_name:
            st.error("Server name is required.")
        elif not mcp_url:
            st.error("MCP URL is required.")
        else:
            # Store the token in session state
            st.session_state[f"mcp_token_{server_name}"] = mcp_token

    # Inputs Tab
    with tab2:
        st.header("Input Configurations")
            try:
                # Register the MCP server as a toolgroup
                toolgroup_id = f"mcp::{server_name}"
                try:
                    llama_stack_api.client.toolgroups.register(
                        toolgroup_id=toolgroup_id,
                        provider_id="model-context-protocol",
                        mcp_endpoint=McpEndpoint(uri=mcp_url),
                        timeout=10.0,  # Set a reasonable timeout
                    )
                except Exception as e:
                    if "timeout" in str(e).lower():
                        st.warning(f"Registration timed out, but configuration will still be saved. Error: {str(e)}")
                    else:
                        raise

        inputs = config.get("inputs", [])
        if not inputs:
            st.info("No input configurations defined yet. Add a new input below.")
                # Update provider data with MCP headers
                # Check if provider_data attribute exists
                if not hasattr(llama_stack_api, "provider_data"):
                    llama_stack_api.provider_data = {}

        # Input selection or creation
        input_options = [f"{i.get('id', f'Input {idx}')} ({i.get('type', 'promptString')})" for idx, i in enumerate(inputs)]
        input_options.append("+ Add New Input")
                provider_data = llama_stack_api.provider_data.copy()
                if "mcp_headers" not in provider_data:
                    provider_data["mcp_headers"] = {}

        selected_input_option = st.selectbox("Select Input", input_options)
                # Add MCP headers
                if mcp_token:
                    # Clean the token (remove 'Bearer ' prefix if present)
                    clean_token = mcp_token
                    if clean_token.lower().startswith("bearer "):
                        clean_token = clean_token[7:]

        if selected_input_option == "+ Add New Input":
            if st.button("Create Input"):
                inputs.append({
                    "type": "promptString",
                    "id": f"input_{len(inputs)}",
                    "description": "",
                    "password": False
                })
                save_config(config)
                st.experimental_rerun()
        else:
            # Edit existing input
            selected_idx = input_options.index(selected_input_option)
            if selected_idx < len(inputs):
                updated_input = render_input_config(inputs[selected_idx], selected_idx)
                    # Set the headers for this MCP server
                    # The key needs to be the exact URL used in the MCP endpoint
                    provider_data["mcp_headers"][mcp_url] = {
                        "Authorization": f"Bearer {clean_token}"
                    }

                col1, col2 = st.columns([1, 1])
                if col1.button("Update Input", key=f"update_input_{selected_idx}"):
                    inputs[selected_idx] = updated_input
                    save_config(config)
                    # Debug information
                    st.info(f"Set authentication headers for {mcp_url}: Bearer {clean_token[:4]}...")

                if col2.button("Delete Input", key=f"delete_input_{selected_idx}"):
                    inputs.pop(selected_idx)
                    save_config(config)
                    st.experimental_rerun()
                    # Also set the token directly in the provider_data using different formats
                    # This increases the chance that one of them will work
                    provider_data[f"{server_name}_api_key"] = clean_token
                    provider_data[f"{server_name}_token"] = clean_token
                    provider_data[f"{server_name}_mcp_token"] = clean_token

    # Input Values Tab
    with tab3:
        st.header("Input Values")
                # Display the current provider_data for debugging
                st.write("Current provider_data:")

        inputs = config.get("inputs", [])
        if not inputs:
            st.info("No input configurations defined yet. Add inputs in the Inputs tab.")
        else:
            st.write("Enter values for the configured inputs:")
                # Mask the token for display
                masked_token = ""
                if clean_token:
                    if len(clean_token) > 8:
                        # Show first 2 and last 2 characters, mask the middle with asterisks
                        masked_token = f"Bearer {clean_token[:2]}{'*' * 6}{clean_token[-2:]}"
                    else:
                        # For short tokens, just show first 2 chars and asterisks
                        masked_token = f"Bearer {clean_token[:2]}{'*' * 4}"

            for input_config in inputs:
                input_id = input_config.get("id", "")
                description = input_config.get("description", "")
                is_password = input_config.get("password", False)
                # Create a sanitized version of mcp_headers for display
                sanitized_headers = {}
                for url, headers in provider_data.get("mcp_headers", {}).items():
                    sanitized_headers[url] = {}
                    for header_key, header_value in headers.items():
                        if header_key.lower() == "authorization" and isinstance(header_value, str):
                            if header_value.lower().startswith("bearer "):
                                token = header_value[7:]
                                if len(token) > 8:
                                    sanitized_headers[url][header_key] = f"Bearer {token[:2]}{'*' * 6}{token[-2:]}"
                                else:
                                    sanitized_headers[url][header_key] = f"Bearer {token[:2]}{'*' * 4}"
                            else:
                                sanitized_headers[url][header_key] = f"{header_value[:2]}{'*' * 6}"
                        else:
                            sanitized_headers[url][header_key] = header_value

                # Get value from session state if available
                current_value = st.session_state.get(f"input_value_{input_id}", "")
                st.json({
                    "mcp_headers": sanitized_headers,
                    f"{server_name}_api_key": masked_token
                })

                # Input field for value
                value = st.text_input(
                    f"{description} ({input_id})",
                    value=current_value,
                    type="password" if is_password else "default",
                    key=f"input_value_{input_id}"
                )
                # Add a note about authentication
                st.warning("""
**Important Authentication Notes:**

            if st.button("Save Input Values"):
                # Values are automatically saved to session state
                # Register MCP servers to apply the new values
                register_mcp_servers(config)
                st.success("Input values saved successfully!")
1. If you encounter authentication errors when using this MCP server,
try restarting the UI application to ensure the authentication headers are properly applied.

    # JSON Config Tab
    with tab4:
        st.header("Raw JSON Configuration")
2. For GitHub Copilot, make sure you're using a valid GitHub Personal Access Token with the 'copilot' scope.

        # Display and edit raw JSON
        json_str = json.dumps(config, indent=2)
        edited_json = st.text_area("Edit JSON Configuration", json_str, height=400)
3. When using the MCP server in code, you may need to explicitly pass the authentication headers:
```python
import json
from llama_stack_client import Agent

        if st.button("Update from JSON"):
            try:
                updated_config = json.loads(edited_json)
                save_config(updated_config)
                st.experimental_rerun()
            except json.JSONDecodeError as e:
                st.error(f"Invalid JSON: {str(e)}")
agent = Agent(
    model="llama-3-70b-instruct",
    tools=["mcp::github"],
    extra_headers={
        "X-LlamaStack-Provider-Data": json.dumps({
            "mcp_headers": {
                "https://api.githubcopilot.com/mcp/": {
                    "Authorization": "Bearer YOUR_TOKEN_HERE"
                }
            }
        })
    }
)
```
                """)

    # Example configuration and usage
    with st.expander("Example Configuration and Usage"):
                # Update the client with the new provider data
                if hasattr(llama_stack_api, "update_provider_data_dict"):
                    llama_stack_api.update_provider_data_dict(provider_data)
                else:
                    # Fallback implementation if method doesn't exist
                    llama_stack_api.provider_data = provider_data
                    # Reinitialize the client with updated provider data
                    llama_stack_api.client = LlamaStackClient(
                        base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
                        provider_data=provider_data,
                    )

                st.success(f"MCP server '{server_name}' configured successfully!")
            except Exception as e:
                st.error(f"Failed to configure MCP server '{server_name}': {str(e)}")

    # Usage example
    with st.expander("Usage Example"):
        st.markdown("""
### Example Configuration
### Using MCP Servers in Code

```json
{
  "servers": {
    "github": {
      "type": "http",
      "url": "https://api.githubcopilot.com/mcp/",
      "headers": {
        "Authorization": "Bearer ${input:github_mcp_pat}"
      }
    }
  },
  "inputs": [
    {
      "type": "promptString",
      "id": "github_mcp_pat",
      "description": "GitHub Personal Access Token",
      "password": true
    }
  ]
}
```

### Usage in Code

Once registered, you can use the MCP server in your code:
Once configured, you can use the MCP server in your code:

```python
from llama_stack_client import Agent

agent = Agent(
    model="llama-3-70b-instruct",
    tools=["mcp::github"],  # Use the registered MCP server
    tools=["mcp::your_server_name"],  # Use the registered MCP server
)

agent.create_turn("Use GitHub Copilot to help me write a function")
agent.create_turn("Use the MCP server to help me with a task")
```

### With Explicit Authentication Headers

If you encounter authentication issues, you can explicitly pass the authentication headers:

```python
import json
from llama_stack_client import Agent

agent = Agent(
    model="llama-3-70b-instruct",
    tools=["mcp::github"],
    extra_headers={
        "X-LlamaStack-Provider-Data": json.dumps({
            "mcp_headers": {
                "https://api.githubcopilot.com/mcp/": {
                    "Authorization": "Bearer YOUR_TOKEN_HERE"
                }
            }
        })
    }
)
```
        """)

    # Display registered toolgroups
    st.header("Registered MCP Toolgroups")
    try:
        toolgroups = llama_stack_api.client.toolgroups.list(timeout=5.0)  # Set a reasonable timeout

        # Debug the structure of the toolgroups
        if toolgroups:
            # Check the first toolgroup to determine its structure
            first_toolgroup = toolgroups[0] if toolgroups else None

            if first_toolgroup:
                # Get the attribute names
                attr_names = dir(first_toolgroup)

                # The ToolGroup class has an 'identifier' attribute as per the API definition
                id_attr = 'identifier'

                # Filter MCP toolgroups based on the identifier attribute
                mcp_toolgroups = []
                for tg in toolgroups:
                    if hasattr(tg, 'identifier') and isinstance(tg.identifier, str) and tg.identifier.startswith("mcp::"):
                        mcp_toolgroups.append(tg)

                if mcp_toolgroups:
                    # Display MCP servers with unregister buttons
                    st.write("Click the button next to a server to unregister it:")

                    for tg in mcp_toolgroups:
                        col1, col2, col3, col4 = st.columns([3, 2, 3, 1])

                        with col1:
                            st.write(f"**{tg.identifier}**")

                        with col2:
                            st.write(tg.provider_id)

                        with col3:
                            if hasattr(tg, 'mcp_endpoint') and tg.mcp_endpoint:
                                st.write(tg.mcp_endpoint.uri)
                            else:
                                st.write("N/A")

                        with col4:
                            # Extract server name from identifier (remove "mcp::" prefix)
                            server_name = tg.identifier.replace("mcp::", "")
                            if st.button("Unregister", key=f"unregister_{server_name}"):
                                try:
                                    # Call the unregister API
                                    llama_stack_api.client.toolgroups.unregister(
                                        toolgroup_id=tg.identifier,
                                        timeout=5.0
                                    )

                                    # Also clean up provider data
                                    if hasattr(llama_stack_api, "provider_data"):
                                        provider_data = llama_stack_api.provider_data.copy()

                                        # Remove MCP headers for this server if they exist
                                        if "mcp_headers" in provider_data and hasattr(tg, 'mcp_endpoint') and tg.mcp_endpoint:
                                            if tg.mcp_endpoint.uri in provider_data["mcp_headers"]:
                                                del provider_data["mcp_headers"][tg.mcp_endpoint.uri]

                                        # Remove server-specific tokens
                                        keys_to_remove = [
                                            f"{server_name}_api_key",
                                            f"{server_name}_token",
                                            f"{server_name}_mcp_token"
                                        ]
                                        for key in keys_to_remove:
                                            if key in provider_data:
                                                del provider_data[key]

                                        # Update the client with the modified provider data
                                        if hasattr(llama_stack_api, "update_provider_data_dict"):
                                            llama_stack_api.update_provider_data_dict(provider_data)
                                        else:
                                            # Fallback implementation
                                            llama_stack_api.provider_data = provider_data
                                            llama_stack_api.client = LlamaStackClient(
                                                base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
                                                provider_data=provider_data,
                                            )

                                    st.success(f"Successfully unregistered {tg.identifier}")
                                    st.rerun()  # Refresh the page to update the list
                                except Exception as e:
                                    st.error(f"Failed to unregister {tg.identifier}: {str(e)}")
                else:
                    st.info("No MCP toolgroups registered yet.")
            else:
                st.info("No toolgroups found.")
        else:
            st.info("No toolgroups returned from the API.")
    except Exception as e:
        if "timeout" in str(e).lower():
            st.warning("Listing toolgroups timed out. The server might be busy or unreachable.")
        else:
            st.error(f"Failed to list toolgroups: {str(e)}")


if __name__ == "__main__":
    main()
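The page code above threads per-server authentication through `provider_data["mcp_headers"]`, keyed by the exact MCP endpoint URL, and then constructs the client with that data. A condensed sketch of the same pattern follows; the URL and token are placeholders, and the constructor call mirrors the one shown in the diff.

```python
# Condensed sketch of the provider-data pattern used in the diff above.
# The MCP URL and token values are placeholders.
import os

from llama_stack_client import LlamaStackClient

mcp_url = "https://api.githubcopilot.com/mcp/"
provider_data = {
    "mcp_headers": {
        # The key must be the exact URL registered as the MCP endpoint.
        mcp_url: {"Authorization": "Bearer YOUR_TOKEN_HERE"},
    }
}

client = LlamaStackClient(
    base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
    provider_data=provider_data,
)
```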
@@ -80,9 +80,14 @@ def tool_chat_page():
    total_tools = 0

    for toolgroup_id in toolgroup_selection:
        tools = client.tools.list(toolgroup_id=toolgroup_id)
        grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
        total_tools += len(tools)
        try:
            # Add a timeout of 5 seconds to prevent UI freezing
            tools = client.tools.list(toolgroup_id=toolgroup_id, timeout=5.0)
            grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
            total_tools += len(tools)
        except Exception as e:
            st.warning(f"Failed to list tools for {toolgroup_id}: {str(e)}")
            grouped_tools[toolgroup_id] = []

    st.markdown(f"Active Tools: 🛠 {total_tools}")
@@ -126,26 +131,57 @@ def tool_chat_page():

    @st.cache_resource
    def create_agent():
        if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
            return ReActAgent(
                client=client,
                model=model,
                tools=toolgroup_selection,
                response_format={
                    "type": "json_schema",
                    "json_schema": ReActOutput.model_json_schema(),
                },
                sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
            )
        else:
            return Agent(
                client,
                model=model,
                instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
                tools=toolgroup_selection,
                sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
        try:
            if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
                return ReActAgent(
                    client=client,
                    model=model,
                    tools=toolgroup_selection,
                    response_format={
                        "type": "json_schema",
                        "json_schema": ReActOutput.model_json_schema(),
                    },
                    sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
                )
            else:
                return Agent(
                    client,
                    model=model,
                    instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
                    tools=toolgroup_selection,
                    sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
                )
        except Exception as e:
            # Log the error
            st.error(f"Failed to create agent: {str(e)}")

            # Create a fallback agent without tools
            st.warning(
                "Creating a fallback agent without tools due to an error. "
                "Some functionality may be limited. Try refreshing the page or selecting different tools."
            )

            # Return a basic agent without tools
            if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
                return ReActAgent(
                    client=client,
                    model=model,
                    tools=[],  # No tools
                    response_format={
                        "type": "json_schema",
                        "json_schema": ReActOutput.model_json_schema(),
                    },
                    sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
                )
            else:
                return Agent(
                    client,
                    model=model,
                    instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
                    tools=[],  # No tools
                    sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
                )

    st.session_state.agent_type = agent_type

    agent = create_agent()
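The `create_agent` function above is wrapped in `@st.cache_resource`, so the constructed agent is reused across Streamlit reruns rather than being rebuilt on every interaction, with a tool-less fallback if construction fails. Below is a minimal, self-contained sketch of that caching pattern; the cached object here is a stand-in dict, not the page's agent.

```python
# Minimal sketch of the @st.cache_resource pattern used by create_agent above.
# The resource is a stand-in dict so the example stays self-contained.
import streamlit as st


@st.cache_resource
def get_shared_resource() -> dict:
    # Runs once per process; later reruns reuse the same object.
    return {"constructed": True}


resource = get_shared_resource()
st.write(resource)

# If the cached object ever needs rebuilding (e.g. after configuration changes),
# Streamlit provides an explicit reset:
# st.cache_resource.clear()
```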