diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py
index bb31bd2a7..be222f840 100644
--- a/llama_stack/distribution/ui/page/playground/rag.py
+++ b/llama_stack/distribution/ui/page/playground/rag.py
@@ -16,6 +16,13 @@ from llama_stack.distribution.ui.modules.utils import data_url_from_file
 def rag_chat_page():
     st.title("🦙 RAG")
 
+    def reset_agent_and_chat():
+        st.session_state.clear()
+        st.cache_resource.clear()
+
+    def should_disable_input():
+        return "messages" in st.session_state and len(st.session_state.messages) > 0
+
     with st.sidebar:
         # File/Directory Upload Section
         st.subheader("Upload Documents")
@@ -69,21 +76,27 @@ def rag_chat_page():
         vector_dbs = llama_stack_api.client.vector_dbs.list()
         vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
         selected_vector_dbs = st.multiselect(
-            "Select Vector Databases",
-            vector_dbs,
+            label="Select Vector Databases",
+            options=vector_dbs,
+            on_change=reset_agent_and_chat,
+            disabled=should_disable_input(),
         )
 
         available_models = llama_stack_api.client.models.list()
         available_models = [model.identifier for model in available_models if model.model_type == "llm"]
         selected_model = st.selectbox(
-            "Choose a model",
-            available_models,
+            label="Choose a model",
+            options=available_models,
             index=0,
+            on_change=reset_agent_and_chat,
+            disabled=should_disable_input(),
         )
         system_prompt = st.text_area(
             "System Prompt",
             value="You are a helpful assistant. ",
             help="Initial instructions given to the AI to set its behavior and context",
+            on_change=reset_agent_and_chat,
+            disabled=should_disable_input(),
         )
         temperature = st.slider(
             "Temperature",
@@ -92,6 +105,8 @@ def rag_chat_page():
             value=0.0,
             step=0.1,
             help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
+            on_change=reset_agent_and_chat,
+            disabled=should_disable_input(),
         )
 
         top_p = st.slider(
@@ -100,12 +115,14 @@ def rag_chat_page():
             max_value=1.0,
             value=0.95,
             step=0.1,
+            on_change=reset_agent_and_chat,
+            disabled=should_disable_input(),
         )
 
         # Add clear chat button to sidebar
         if st.button("Clear Chat", use_container_width=True):
-            st.session_state.clear()
-            st.cache_resource.clear()
+            reset_agent_and_chat()
+            st.rerun()
 
     # Chat Interface
     if "messages" not in st.session_state:
@@ -151,15 +168,8 @@ def rag_chat_page():
 
     session_id = st.session_state["agent_session_id"]
 
-    # Chat input
-    if prompt := st.chat_input("Ask a question about your documents"):
-        # Add user message to chat history
-        st.session_state.messages.append({"role": "user", "content": prompt})
-
-        # Display user message
-        with st.chat_message("user"):
-            st.markdown(prompt)
-
+    def process_prompt(prompt):
+        # Send the prompt to the agent
         response = agent.create_turn(
             messages=[
                 {
@@ -188,5 +198,24 @@ def rag_chat_page():
 
         st.session_state.messages.append({"role": "assistant", "content": full_response})
 
+    # Chat input
+    if prompt := st.chat_input("Ask a question about your documents"):
+        # Add user message to chat history
+        st.session_state.messages.append({"role": "user", "content": prompt})
+
+        # Display user message
+        with st.chat_message("user"):
+            st.markdown(prompt)
+
+        # store the prompt to process it after page refresh
+        st.session_state.prompt = prompt
+
+        # force page refresh to disable the settings widgets
+        st.rerun()
+
+    if "prompt" in st.session_state and st.session_state.prompt is not None:
+        process_prompt(st.session_state.prompt)
+        st.session_state.prompt = None
+
 
 rag_chat_page()