diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py index fcd0f908e..bb31bd2a7 100644 --- a/llama_stack/distribution/ui/page/playground/rag.py +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import uuid + import streamlit as st from llama_stack_client import Agent, AgentEventLogger, RAGDocument @@ -102,8 +104,8 @@ def rag_chat_page(): # Add clear chat button to sidebar if st.button("Clear Chat", use_container_width=True): - st.session_state.messages = [] - st.rerun() + st.session_state.clear() + st.cache_resource.clear() # Chat Interface if "messages" not in st.session_state: @@ -123,23 +125,31 @@ def rag_chat_page(): else: strategy = {"type": "greedy"} - agent = Agent( - llama_stack_api.client, - model=selected_model, - instructions=system_prompt, - sampling_params={ - "strategy": strategy, - }, - tools=[ - dict( - name="builtin::rag/knowledge_search", - args={ - "vector_db_ids": list(selected_vector_dbs), - }, - ) - ], - ) - session_id = agent.create_session("rag-session") + @st.cache_resource + def create_agent(): + return Agent( + llama_stack_api.client, + model=selected_model, + instructions=system_prompt, + sampling_params={ + "strategy": strategy, + }, + tools=[ + dict( + name="builtin::rag/knowledge_search", + args={ + "vector_db_ids": list(selected_vector_dbs), + }, + ) + ], + ) + + agent = create_agent() + + if "agent_session_id" not in st.session_state: + st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}") + + session_id = st.session_state["agent_session_id"] # Chat input if prompt := st.chat_input("Ask a question about your documents"):