# What does this PR do?

This PR lets users select an existing vector DB to use with their agent on the Tools page of the playground. The drop-down menu for selecting a vector DB only appears when the RAG tool is selected. Without this change, there is no way for a user to specify which vector DB their RAG tool should use on the Tools page. I have intentionally left the RAG options sparse here, since the full RAG options are exposed on the RAG page.

## Test Plan

Without these changes, the RAG tool throws the following error:

`name: knowledge_search) does not have any content`

With these changes, the RAG tool works as expected.

Signed-off-by: Michael Clifford <mcliffor@redhat.com>
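For reference, the core of the change as it appears in the page source below (a condensed excerpt, not a standalone snippet): when `builtin::rag` is selected, a sidebar multiselect lists the registered vector DBs, and the chosen identifiers are later attached to the RAG toolgroup as `vector_db_ids`.

```python
# Excerpt from tool_chat_page() below (sidebar section):
if "builtin::rag" in toolgroup_selection:
    vector_dbs = llama_stack_api.client.vector_dbs.list() or []
    vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
    selected_vector_dbs = st.multiselect(
        label="Select Document Collections to use in RAG queries",
        options=vector_dbs,
        on_change=reset_agent,
    )

# Later, before building the agent, the plain toolgroup name is swapped for a
# dict that carries the selection as tool arguments:
tool_dict = dict(
    name="builtin::rag",
    args={"vector_db_ids": list(selected_vector_dbs)},
)
```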
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid

import streamlit as st
from llama_stack_client import Agent

from llama_stack.distribution.ui.modules.api import llama_stack_api


def tool_chat_page():
    st.title("🛠 Tools")

    client = llama_stack_api.client
    models = client.models.list()
    model_list = [model.identifier for model in models if model.api_model_type == "llm"]

    tool_groups = client.toolgroups.list()
    tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
    mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
    builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
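
    # Changing any sidebar option rebuilds the agent: clear the chat session
    # state and the cached agent so a fresh one is created on the next rerun.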
    def reset_agent():
        st.session_state.clear()
        st.cache_resource.clear()

    with st.sidebar:
        st.subheader("Model")
        model = st.selectbox(label="models", options=model_list, on_change=reset_agent)

        st.subheader("Builtin Tools")
        toolgroup_selection = st.pills(
            label="Available ToolGroups", options=builtin_tools_list, selection_mode="multi", on_change=reset_agent
        )
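
        # Only offer the vector DB picker when the RAG tool is selected; the
        # chosen identifiers are passed to the rag tool as vector_db_ids below.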
if "builtin::rag" in toolgroup_selection:
|
|
vector_dbs = llama_stack_api.client.vector_dbs.list() or []
|
|
if not vector_dbs:
|
|
st.info("No vector databases available for selection.")
|
|
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
|
|
selected_vector_dbs = st.multiselect(
|
|
label="Select Document Collections to use in RAG queries",
|
|
options=vector_dbs,
|
|
on_change=reset_agent,
|
|
)
|
|
|
|
st.subheader("MCP Servers")
|
|
mcp_selection = st.pills(
|
|
label="Available MCP Servers", options=mcp_tools_list, selection_mode="multi", on_change=reset_agent
|
|
)
|
|
|
|
toolgroup_selection.extend(mcp_selection)
|
|
|
|
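
        # List the individual tools exposed by the selected toolgroups so the
        # sidebar can show how many are currently active.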
        active_tool_list = []
        for toolgroup_id in toolgroup_selection:
            active_tool_list.extend(
                [
                    f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}"
                    for t in client.tools.list(toolgroup_id=toolgroup_id)
                ]
            )

        st.subheader(f"Active Tools: 🛠 {len(active_tool_list)}")
        st.json(active_tool_list)
st.subheader("Chat Configurations")
|
|
max_tokens = st.slider(
|
|
"Max Tokens",
|
|
min_value=0,
|
|
max_value=4096,
|
|
value=512,
|
|
step=1,
|
|
help="The maximum number of tokens to generate",
|
|
on_change=reset_agent,
|
|
)
|
|
|
|
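
    # Replace the plain "builtin::rag" toolgroup name with a dict form that
    # carries the selected vector DB identifiers as tool arguments.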
    for i, tool_name in enumerate(toolgroup_selection):
        if tool_name == "builtin::rag":
            tool_dict = dict(
                name="builtin::rag",
                args={
                    "vector_db_ids": list(selected_vector_dbs),
                },
            )
            toolgroup_selection[i] = tool_dict
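
    # Cache the agent across Streamlit reruns; reset_agent() clears this cache
    # whenever a sidebar option changes, so the agent is rebuilt with the new
    # configuration.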
    @st.cache_resource
    def create_agent():
        return Agent(
            client,
            model=model,
            instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
            tools=toolgroup_selection,
            sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
        )

    agent = create_agent()
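
    # Create one agent session per browser session and keep its ID in
    # session_state so Streamlit reruns reuse it.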
if "agent_session_id" not in st.session_state:
|
|
st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}")
|
|
|
|
session_id = st.session_state["agent_session_id"]
|
|
|
|
if "messages" not in st.session_state:
|
|
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
|
|
|
|

    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])
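
    # Handle a new user prompt: echo it, record it in the history, and stream
    # the agent's reply for this session.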
    if prompt := st.chat_input(placeholder=""):
        with st.chat_message("user"):
            st.markdown(prompt)

        st.session_state.messages.append({"role": "user", "content": prompt})

        turn_response = agent.create_turn(
            session_id=session_id,
            messages=[{"role": "user", "content": prompt}],
            stream=True,
        )
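
        # Stream the turn back to the UI: yield text deltas as they arrive and
        # a wrench marker whenever a tool-execution step completes.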
        def response_generator(turn_response):
            for response in turn_response:
                if hasattr(response.event, "payload"):
                    print(response.event.payload)
                    if response.event.payload.event_type == "step_progress":
                        if hasattr(response.event.payload.delta, "text"):
                            yield response.event.payload.delta.text
                    if response.event.payload.event_type == "step_complete":
                        if response.event.payload.step_details.step_type == "tool_execution":
                            yield " 🛠 "
                else:
                    yield f"Error occurred in the Llama Stack Cluster: {response}"

        with st.chat_message("assistant"):
            response = st.write_stream(response_generator(turn_response))

        st.session_state.messages.append({"role": "assistant", "content": response})


tool_chat_page()