fix: Playground RAG page errors (#1928)

# What does this PR do?
This PR fixes two issues with the RAG page of the Playground UI:

1. When the user modifies a configurable setting via a widget (e.g., the
system prompt or temperature), the agent is not recreated. Thus, the
change has no effect, and the user receives no indication of this.
2. After the first issue is fixed, it becomes possible to recreate the
agent mid-conversation or even mid-generation. To mitigate this, widgets
related to agent configuration are now disabled when a conversation is
in progress (i.e., when the chat is non-empty). They are automatically
enabled again when the user resets the chat history.

## Test Plan

- Launch the Playground and go to the RAG page;
- Select the vector DB ID;
- Send a message to the agent via the chat;
- The widgets controlling the agent parameters become disabled at
this point;
- Send a second message asking the model about the content of the first
message;
- The reply will indicate that the two messages were sent over the same
session, that is, the agent was not recreated;
- Click the 'Clear Chat' button;
- All widgets will be enabled and a new agent will be created (which can
be validated by sending another message).
This commit is contained in:
Ilya Kolchinsky 2025-04-10 22:38:31 +02:00 committed by GitHub
parent de6ec5803e
commit 79fc81f78f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -16,6 +16,13 @@ from llama_stack.distribution.ui.modules.utils import data_url_from_file
def rag_chat_page(): def rag_chat_page():
st.title("🦙 RAG") st.title("🦙 RAG")
def reset_agent_and_chat():
st.session_state.clear()
st.cache_resource.clear()
def should_disable_input():
return "messages" in st.session_state and len(st.session_state.messages) > 0
with st.sidebar: with st.sidebar:
# File/Directory Upload Section # File/Directory Upload Section
st.subheader("Upload Documents") st.subheader("Upload Documents")
@@ -69,21 +76,27 @@ def rag_chat_page():
vector_dbs = llama_stack_api.client.vector_dbs.list() vector_dbs = llama_stack_api.client.vector_dbs.list()
vector_dbs = [vector_db.identifier for vector_db in vector_dbs] vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
selected_vector_dbs = st.multiselect( selected_vector_dbs = st.multiselect(
"Select Vector Databases", label="Select Vector Databases",
vector_dbs, options=vector_dbs,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
) )
available_models = llama_stack_api.client.models.list() available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models if model.model_type == "llm"] available_models = [model.identifier for model in available_models if model.model_type == "llm"]
selected_model = st.selectbox( selected_model = st.selectbox(
"Choose a model", label="Choose a model",
available_models, options=available_models,
index=0, index=0,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
) )
system_prompt = st.text_area( system_prompt = st.text_area(
"System Prompt", "System Prompt",
value="You are a helpful assistant. ", value="You are a helpful assistant. ",
help="Initial instructions given to the AI to set its behavior and context", help="Initial instructions given to the AI to set its behavior and context",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
) )
temperature = st.slider( temperature = st.slider(
"Temperature", "Temperature",
@@ -92,6 +105,8 @@ def rag_chat_page():
value=0.0, value=0.0,
step=0.1, step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable", help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
) )
top_p = st.slider( top_p = st.slider(
@@ -100,12 +115,14 @@ def rag_chat_page():
max_value=1.0, max_value=1.0,
value=0.95, value=0.95,
step=0.1, step=0.1,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
) )
# Add clear chat button to sidebar # Add clear chat button to sidebar
if st.button("Clear Chat", use_container_width=True): if st.button("Clear Chat", use_container_width=True):
st.session_state.clear() reset_agent_and_chat()
st.cache_resource.clear() st.rerun()
# Chat Interface # Chat Interface
if "messages" not in st.session_state: if "messages" not in st.session_state:
@@ -151,15 +168,8 @@ def rag_chat_page():
session_id = st.session_state["agent_session_id"] session_id = st.session_state["agent_session_id"]
# Chat input def process_prompt(prompt):
if prompt := st.chat_input("Ask a question about your documents"): # Send the prompt to the agent
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
response = agent.create_turn( response = agent.create_turn(
messages=[ messages=[
{ {
@@ -188,5 +198,24 @@ def rag_chat_page():
st.session_state.messages.append({"role": "assistant", "content": full_response}) st.session_state.messages.append({"role": "assistant", "content": full_response})
# Chat input
if prompt := st.chat_input("Ask a question about your documents"):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# store the prompt to process it after page refresh
st.session_state.prompt = prompt
# force page refresh to disable the settings widgets
st.rerun()
if "prompt" in st.session_state and st.session_state.prompt is not None:
process_prompt(st.session_state.prompt)
st.session_state.prompt = None
rag_chat_page() rag_chat_page()