From c6e93e32f62cbe8e67bc56e0822fc2ff1b04f48b Mon Sep 17 00:00:00 2001 From: Michael Clifford Date: Tue, 8 Apr 2025 03:46:13 -0400 Subject: [PATCH] feat: Updated playground rag to use session id for persistent conversation (#1870) # What does this PR do? This PR updates the [playground RAG example](llama_stack/distribution/ui/page/playground/rag.py) so that the agent is able to use its built-in conversation history. Here we are using streamlit's `cache_resource` functionality to prevent the agent from re-initializing after every interaction as well as storing its session_id in the `session_state`. This allows the agent in the RAG example to behave more like it does when using the python-client directly. [//]: # (If resolving an issue, uncomment and update the line below) Closes #1869 ## Test Plan Without these changes, if you ask it "What is 2 + 2?" followed by the question "What did I just ask?", it will provide an obviously incorrect answer. With these changes, you can ask the same series of questions and it will provide the correct answer. [//]: # (## Documentation) Signed-off-by: Michael Clifford --- .../distribution/ui/page/playground/rag.py | 48 +++++++++++-------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py index fcd0f908e..bb31bd2a7 100644 --- a/llama_stack/distribution/ui/page/playground/rag.py +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import uuid + import streamlit as st from llama_stack_client import Agent, AgentEventLogger, RAGDocument @@ -102,8 +104,8 @@ def rag_chat_page(): # Add clear chat button to sidebar if st.button("Clear Chat", use_container_width=True): - st.session_state.messages = [] - st.rerun() + st.session_state.clear() + st.cache_resource.clear() # Chat Interface if "messages" not in st.session_state: @@ -123,23 +125,31 @@ def rag_chat_page(): else: strategy = {"type": "greedy"} - agent = Agent( - llama_stack_api.client, - model=selected_model, - instructions=system_prompt, - sampling_params={ - "strategy": strategy, - }, - tools=[ - dict( - name="builtin::rag/knowledge_search", - args={ - "vector_db_ids": list(selected_vector_dbs), - }, - ) - ], - ) - session_id = agent.create_session("rag-session") + @st.cache_resource + def create_agent(): + return Agent( + llama_stack_api.client, + model=selected_model, + instructions=system_prompt, + sampling_params={ + "strategy": strategy, + }, + tools=[ + dict( + name="builtin::rag/knowledge_search", + args={ + "vector_db_ids": list(selected_vector_dbs), + }, + ) + ], + ) + + agent = create_agent() + + if "agent_session_id" not in st.session_state: + st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}") + + session_id = st.session_state["agent_session_id"] # Chat input if prompt := st.chat_input("Ask a question about your documents"):