From 40f41af2f74078028f0d79ecc291722884679d1c Mon Sep 17 00:00:00 2001
From: Ilya Kolchinsky <58424190+ilya-kolchinsky@users.noreply.github.com>
Date: Fri, 11 Apr 2025 19:16:10 +0200
Subject: [PATCH] feat: Add a direct (non-agentic) RAG option to the Playground
 RAG page (#1940)

# What does this PR do?

This PR makes it possible to switch between agentic and non-agentic RAG when
running the RAG page of the Playground. When non-agentic (direct) RAG is
selected, user queries are answered by directly querying the vector DB,
augmenting the prompt with the retrieved context, and sending the extended
prompt to the model via the Inference API.
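Outside of Streamlit, the direct flow added here boils down to three client
calls. A minimal sketch against a running Llama Stack server (the base URL,
model id, and vector DB id below are illustrative, not part of this change):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # illustrative server URL
query = "What does the uploaded document say about retrieval?"

# 1. Retrieve relevant chunks directly from the vector DB via the RAG tool runtime.
rag_response = client.tool_runtime.rag_tool.query(
    content=query, vector_db_ids=["rag_vector_db"]  # illustrative collection id
)

# 2. Augment the user query with the retrieved context.
extended_prompt = (
    "Please answer the following query using the context below.\n\n"
    f"CONTEXT:\n{rag_response.content}\n\nQUERY:\n{query}"
)

# 3. Send the extended prompt to the model via the Inference API.
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.3-70B-Instruct",  # illustrative model id
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},  # stands in for the page's system prompt
        {"role": "user", "content": extended_prompt},
    ],
)
print(response.completion_message.content)
```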
label="Select Vector Databases", + label="Select Document Collections to use in RAG queries", options=vector_dbs, on_change=reset_agent_and_chat, disabled=should_disable_input(), ) + st.subheader("Inference Parameters", divider=True) available_models = llama_stack_api.client.models.list() available_models = [model.identifier for model in available_models if model.model_type == "llm"] selected_model = st.selectbox( @@ -127,9 +141,11 @@ def rag_chat_page(): # Chat Interface if "messages" not in st.session_state: st.session_state.messages = [] + if "displayed_messages" not in st.session_state: + st.session_state.displayed_messages = [] # Display chat history - for message in st.session_state.messages: + for message in st.session_state.displayed_messages: with st.chat_message(message["role"]): st.markdown(message["content"]) @@ -161,14 +177,17 @@ def rag_chat_page(): ], ) - agent = create_agent() + if rag_mode == "Agent-based": + agent = create_agent() + if "agent_session_id" not in st.session_state: + st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}") - if "agent_session_id" not in st.session_state: - st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}") + session_id = st.session_state["agent_session_id"] - session_id = st.session_state["agent_session_id"] + def agent_process_prompt(prompt): + # Add user message to chat history + st.session_state.messages.append({"role": "user", "content": prompt}) - def process_prompt(prompt): # Send the prompt to the agent response = agent.create_turn( messages=[ @@ -197,11 +216,62 @@ def rag_chat_page(): message_placeholder.markdown(full_response) st.session_state.messages.append({"role": "assistant", "content": full_response}) + st.session_state.displayed_messages.append({"role": "assistant", "content": full_response}) + + def direct_process_prompt(prompt): + # Add the system prompt in the beginning of the conversation + if len(st.session_state.messages) == 0: + st.session_state.messages.append({"role": "system", "content": system_prompt}) + + # Query the vector DB + rag_response = llama_stack_api.client.tool_runtime.rag_tool.query( + content=prompt, vector_db_ids=list(selected_vector_dbs) + ) + prompt_context = rag_response.content + + with st.chat_message("assistant"): + retrieval_message_placeholder = st.empty() + message_placeholder = st.empty() + full_response = "" + retrieval_response = "" + + # Display the retrieved content + retrieval_response += str(prompt_context) + retrieval_message_placeholder.info(retrieval_response) + + # Construct the extended prompt + extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}" + + # Run inference directly + st.session_state.messages.append({"role": "user", "content": extended_prompt}) + response = llama_stack_api.client.inference.chat_completion( + messages=st.session_state.messages, + model_id=selected_model, + sampling_params={ + "strategy": strategy, + }, + stream=True, + ) + + # Display assistant response + for chunk in response: + response_delta = chunk.event.delta + if isinstance(response_delta, ToolCallDelta): + retrieval_response += response_delta.tool_call.replace("====", "").strip() + retrieval_message_placeholder.info(retrieval_response) + else: + full_response += chunk.event.delta.text + message_placeholder.markdown(full_response + "▌") + message_placeholder.markdown(full_response) + + response_dict = {"role": "assistant", 
"content": full_response, "stop_reason": "end_of_message"} + st.session_state.messages.append(response_dict) + st.session_state.displayed_messages.append(response_dict) # Chat input if prompt := st.chat_input("Ask a question about your documents"): # Add user message to chat history - st.session_state.messages.append({"role": "user", "content": prompt}) + st.session_state.displayed_messages.append({"role": "user", "content": prompt}) # Display user message with st.chat_message("user"): @@ -214,7 +284,10 @@ def rag_chat_page(): st.rerun() if "prompt" in st.session_state and st.session_state.prompt is not None: - process_prompt(st.session_state.prompt) + if rag_mode == "Agent-based": + agent_process_prompt(st.session_state.prompt) + else: # rag_mode == "Direct" + direct_process_prompt(st.session_state.prompt) st.session_state.prompt = None