# What does this PR do?

This PR improves the Tools page in the LlamaStack Playground UI by making the active tool list shown in the sidebar easier to read.

- Previously, active tools were displayed as a flat JSON array of verbose identifiers (e.g., `builtin::code_interpreter:code_interpreter`).
- This PR groups tools by their toolgroup (e.g., `builtin::websearch`) and renders each tool name in a simplified, human-readable form (e.g., `web_search`).
- This improves usability when working with multiple toolgroups, especially in configurations involving MCP tools or complex tool identifiers.

Before and After Comparison:

**Before**


**After**


[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

- Followed the [LlamaStack UI Developer Setup instructions](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/distribution/ui)
- Ran the Streamlit UI via: `uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py`
- Selected multiple built-in toolgroups (e.g., `code_interpreter`, `websearch`, `wolfram_alpha`) from the sidebar.

[//]: # (## Documentation)
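At its core, the change replaces the flat JSON dump with a group-then-simplify pass over the active tools. A minimal standalone sketch of that logic (the identifiers here are illustrative, mirroring the examples above; the real page reads them from the server):

```python
# Sketch only: the real page builds this mapping from
# client.tools.list(toolgroup_id=...) inside the Streamlit sidebar.
active_tools = {
    "builtin::websearch": ["builtin::websearch:web_search"],
    "builtin::code_interpreter": ["builtin::code_interpreter:code_interpreter"],
}

for group_id, tools in active_tools.items():
    print(f"Tools from {group_id}:")
    for idx, tool in enumerate(tools, start=1):
        # Keep only the last ':'-separated segment, e.g. "web_search".
        print(f"  {idx}. {tool.split(':')[-1]}")
```

The full updated page: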
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid

import streamlit as st
from llama_stack_client import Agent

from llama_stack.distribution.ui.modules.api import llama_stack_api

def tool_chat_page():
    st.title("🛠 Tools")

    client = llama_stack_api.client
    models = client.models.list()
    model_list = [model.identifier for model in models if model.api_model_type == "llm"]

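    # Split registered toolgroups into MCP servers ("mcp::" prefix) and
    # built-in toolgroups so each set gets its own sidebar widget.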
    tool_groups = client.toolgroups.list()
    tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
    mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
    builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]

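    # Changing any sidebar option clears session state and the cached agent,
    # so the next rerun builds a fresh agent with the new configuration.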
    def reset_agent():
        st.session_state.clear()
        st.cache_resource.clear()

    with st.sidebar:
        st.title("Configuration")
        st.subheader("Model")
        model = st.selectbox(label="Model", options=model_list, on_change=reset_agent, label_visibility="collapsed")

        st.subheader("Available ToolGroups")

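        # Multi-select pills for built-in toolgroups; format_func strips the
        # "builtin::" prefix so labels stay short.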
        toolgroup_selection = st.pills(
            label="Built-in tools",
            options=builtin_tools_list,
            selection_mode="multi",
            on_change=reset_agent,
            format_func=lambda tool: "".join(tool.split("::")[1:]),
            help="List of built-in tools from your llama stack server.",
        )

if "builtin::rag" in toolgroup_selection:
|
|
vector_dbs = llama_stack_api.client.vector_dbs.list() or []
|
|
if not vector_dbs:
|
|
st.info("No vector databases available for selection.")
|
|
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
|
|
selected_vector_dbs = st.multiselect(
|
|
label="Select Document Collections to use in RAG queries",
|
|
options=vector_dbs,
|
|
on_change=reset_agent,
|
|
)
|
|
|
|
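        # MCP servers get their own pills widget, labeled by server name.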
        mcp_selection = st.pills(
            label="MCP Servers",
            options=mcp_tools_list,
            selection_mode="multi",
            on_change=reset_agent,
            format_func=lambda tool: "".join(tool.split("::")[1:]),
            help="List of MCP servers registered to your llama stack server.",
        )

        toolgroup_selection.extend(mcp_selection)

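        # Core of this PR: group the active tools by toolgroup instead of
        # rendering one flat JSON array of verbose identifiers.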
        grouped_tools = {}
        total_tools = 0

        for toolgroup_id in toolgroup_selection:
            tools = client.tools.list(toolgroup_id=toolgroup_id)
            grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
            total_tools += len(tools)

        st.markdown(f"Active Tools: 🛠 {total_tools}")

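        # One expander per toolgroup; show only the last ":"-separated segment
        # of each identifier (e.g. "web_search") for readability.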
        for group_id, tools in grouped_tools.items():
            with st.expander(f"🔧 Tools from `{group_id}`"):
                for idx, tool in enumerate(tools, start=1):
                    st.markdown(f"{idx}. `{tool.split(':')[-1]}`")

        st.subheader("Agent Configurations")
        max_tokens = st.slider(
            "Max Tokens",
            min_value=0,
            max_value=4096,
            value=512,
            step=1,
            help="The maximum number of tokens to generate",
            on_change=reset_agent,
        )

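    # builtin::rag takes arguments, so replace its bare name in the selection
    # with a dict that carries the chosen vector DB ids.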
    for i, tool_name in enumerate(toolgroup_selection):
        if tool_name == "builtin::rag":
            tool_dict = dict(
                name="builtin::rag",
                args={
                    "vector_db_ids": list(selected_vector_dbs),
                },
            )
            toolgroup_selection[i] = tool_dict

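    # Cache the agent across Streamlit reruns; reset_agent() clears this
    # cache whenever the sidebar configuration changes.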
    @st.cache_resource
    def create_agent():
        return Agent(
            client,
            model=model,
            instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
            tools=toolgroup_selection,
            sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
        )

    agent = create_agent()

if "agent_session_id" not in st.session_state:
|
|
st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}")
|
|
|
|
session_id = st.session_state["agent_session_id"]
|
|
|
|
if "messages" not in st.session_state:
|
|
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
|
|
|
|
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    if prompt := st.chat_input(placeholder=""):
        with st.chat_message("user"):
            st.markdown(prompt)

        st.session_state.messages.append({"role": "user", "content": prompt})

        turn_response = agent.create_turn(
            session_id=session_id,
            messages=[{"role": "user", "content": prompt}],
            stream=True,
        )

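        # Translate streamed turn events into text chunks for st.write_stream:
        # token deltas from "step_progress" events, plus a tool-use note when
        # a tool_execution step completes.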
        def response_generator(turn_response):
            for response in turn_response:
                if hasattr(response.event, "payload"):
                    print(response.event.payload)
                    if response.event.payload.event_type == "step_progress":
                        if hasattr(response.event.payload.delta, "text"):
                            yield response.event.payload.delta.text
                    if response.event.payload.event_type == "step_complete":
                        if response.event.payload.step_details.step_type == "tool_execution":
                            if response.event.payload.step_details.tool_calls:
                                tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name)
                                yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n'
                            else:
                                yield "No tool_calls present in step_details"
                else:
                    yield f"Error occurred in the Llama Stack Cluster: {response}"

        with st.chat_message("assistant"):
            response = st.write_stream(response_generator(turn_response))

        st.session_state.messages.append({"role": "assistant", "content": response})


tool_chat_page()
```