forked from phoenix-oss/llama-stack-mirror
		
	# What does this PR do? Tool outputs and chunks retrieved from the vector DB (i.e., everything except the actual model reply) are now hidden under a collapsible expander when presented to the user. # Test Plan Navigate to the RAG page in the Playground UI.
		
			
				
	
	
		
			301 lines
		
	
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			301 lines
		
	
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import uuid
 | |
| 
 | |
| import streamlit as st
 | |
| from llama_stack_client import Agent, AgentEventLogger, RAGDocument
 | |
| 
 | |
| from llama_stack.apis.common.content_types import ToolCallDelta
 | |
| from llama_stack.distribution.ui.modules.api import llama_stack_api
 | |
| from llama_stack.distribution.ui.modules.utils import data_url_from_file
 | |
| 
 | |
| 
 | |
def _content_to_text(content):
    """Flatten a content value returned by the Llama Stack client into plain text.

    ``content`` may be ``None`` (-> ""), a plain string (returned as-is), an
    object exposing a string ``text`` attribute (e.g. a text content item), or a
    possibly nested list/tuple of such values (flattened and joined with
    newlines). Any other object falls back to ``str(obj)``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        return "\n".join(_content_to_text(item) for item in content)
    text = getattr(content, "text", None)
    return text if isinstance(text, str) else str(content)


def rag_chat_page():
    """Render the RAG playground page.

    Sidebar: upload documents into a new vector DB collection, choose the RAG
    mode ("Direct" or "Agent-based"), the collections to query, the model and its
    sampling parameters. Main area: a chat whose answers are grounded in the
    selected collections. Tool outputs and retrieved context are shown inside a
    collapsed expander so that only the model reply is visible by default.

    Settings widgets are disabled once a conversation has started; changing a
    setting (or pressing "Clear Chat") wipes the session state and the cached agent.
    """
    st.title("🦙 RAG")

    def reset_agent_and_chat():
        # A settings change invalidates both the conversation and the cached agent.
        st.session_state.clear()
        st.cache_resource.clear()

    def should_disable_input():
        # Settings are frozen as soon as at least one message has been displayed.
        return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0

    def log_message(message):
        """Replay one stored chat message; auxiliary output goes in a collapsed expander."""
        with st.chat_message(message["role"]):
            if message.get("tool_output"):
                with st.expander(label="Tool Output", expanded=False, icon="🛠"):
                    st.write(message["tool_output"])
            st.markdown(message["content"])

    with st.sidebar:
        # File/Directory Upload Section
        st.subheader("Upload Documents", divider=True)
        uploaded_files = st.file_uploader(
            "Upload file(s) or directory",
            accept_multiple_files=True,
            type=["txt", "pdf", "doc", "docx"],  # Add more file types as needed
        )
        # Process uploaded files
        if uploaded_files:
            st.success(f"Successfully uploaded {len(uploaded_files)} files")
            # Add memory bank name input field
            vector_db_name = st.text_input(
                "Document Collection Name",
                value="rag_vector_db",
                help="Enter a unique identifier for this document collection",
            )
            if st.button("Create Document Collection"):
                documents = [
                    RAGDocument(
                        document_id=uploaded_file.name,
                        content=data_url_from_file(uploaded_file),
                    )
                    for uploaded_file in uploaded_files
                ]

                # Use the last listed vector_io provider; if there is none,
                # provider_id=None is passed through to the register call.
                providers = llama_stack_api.client.providers.list()
                vector_io_provider = None
                for provider in providers:
                    if provider.api == "vector_io":
                        vector_io_provider = provider.provider_id

                llama_stack_api.client.vector_dbs.register(
                    vector_db_id=vector_db_name,  # Use the user-provided name
                    # 384 is the output dimension of all-MiniLM-L6-v2; the two must match.
                    embedding_dimension=384,
                    embedding_model="all-MiniLM-L6-v2",
                    provider_id=vector_io_provider,
                )

                # insert documents using the custom vector db name
                llama_stack_api.client.tool_runtime.rag_tool.insert(
                    vector_db_id=vector_db_name,  # Use the user-provided name
                    documents=documents,
                    chunk_size_in_tokens=512,
                )
                st.success("Vector database created successfully!")

        st.subheader("RAG Parameters", divider=True)

        rag_mode = st.radio(
            "RAG mode",
            ["Direct", "Agent-based"],
            captions=[
                "RAG is performed by directly retrieving the information and augmenting the user query",
                "RAG is performed by an agent activating a dedicated knowledge search tool.",
            ],
            on_change=reset_agent_and_chat,
            disabled=should_disable_input(),
        )

        # select memory banks
        vector_dbs = llama_stack_api.client.vector_dbs.list()
        vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
        selected_vector_dbs = st.multiselect(
            label="Select Document Collections to use in RAG queries",
            options=vector_dbs,
            on_change=reset_agent_and_chat,
            disabled=should_disable_input(),
        )

        st.subheader("Inference Parameters", divider=True)
        available_models = llama_stack_api.client.models.list()
        available_models = [model.identifier for model in available_models if model.model_type == "llm"]
        selected_model = st.selectbox(
            label="Choose a model",
            options=available_models,
            index=0,
            on_change=reset_agent_and_chat,
            disabled=should_disable_input(),
        )
        system_prompt = st.text_area(
            "System Prompt",
            value="You are a helpful assistant. ",
            help="Initial instructions given to the AI to set its behavior and context",
            on_change=reset_agent_and_chat,
            disabled=should_disable_input(),
        )
        temperature = st.slider(
            "Temperature",
            min_value=0.0,
            max_value=1.0,
            value=0.0,
            step=0.1,
            help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
            on_change=reset_agent_and_chat,
            disabled=should_disable_input(),
        )

        top_p = st.slider(
            "Top P",
            min_value=0.0,
            max_value=1.0,
            value=0.95,
            step=0.1,
            on_change=reset_agent_and_chat,
            disabled=should_disable_input(),
        )

        # Add clear chat button to sidebar
        if st.button("Clear Chat", use_container_width=True):
            reset_agent_and_chat()
            st.rerun()

    # Chat Interface
    # "messages" is the history sent to the model (direct mode); "displayed_messages"
    # is what is replayed in the UI and may carry UI-only keys such as "tool_output".
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "displayed_messages" not in st.session_state:
        st.session_state.displayed_messages = []

    # Display chat history
    for message in st.session_state.displayed_messages:
        log_message(message)

    # Temperature 0 means deterministic decoding; otherwise sample with top-p.
    if temperature > 0.0:
        strategy = {
            "type": "top_p",
            "temperature": temperature,
            "top_p": top_p,
        }
    else:
        strategy = {"type": "greedy"}

    @st.cache_resource
    def create_agent():
        # Cached across reruns; reset_agent_and_chat() clears the cache on any settings change.
        return Agent(
            llama_stack_api.client,
            model=selected_model,
            instructions=system_prompt,
            sampling_params={
                "strategy": strategy,
            },
            tools=[
                dict(
                    name="builtin::rag/knowledge_search",
                    args={
                        "vector_db_ids": list(selected_vector_dbs),
                    },
                )
            ],
        )

    if rag_mode == "Agent-based":
        agent = create_agent()
        if "agent_session_id" not in st.session_state:
            st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")

        session_id = st.session_state["agent_session_id"]

    def agent_process_prompt(prompt):
        """Run one agent turn for ``prompt``, streaming the reply and collapsing tool output."""
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Send the prompt to the agent
        response = agent.create_turn(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            session_id=session_id,
        )

        # Display assistant response
        with st.chat_message("assistant"):
            tool_output_expander = st.expander(label="Tool Output", expanded=False, icon="🛠")
            # Single-element placeholder inside the expander: each update replaces
            # the previous text instead of appending yet another copy of it.
            tool_output_placeholder = tool_output_expander.empty()
            message_placeholder = st.empty()
            full_response = ""
            tool_outputs = []
            for log in AgentEventLogger().log(response):
                log.print()
                content = log.content
                # Tool execution results are not part of the reply. Non-string
                # payloads (presumably structured tool calls emitted during
                # inference — TODO confirm) cannot be concatenated to the reply
                # text either, so they are shown with the tool output.
                if log.role == "tool_execution" or not isinstance(content, str):
                    if content:
                        tool_outputs.append(str(content).replace("====", "").strip())
                        tool_output_placeholder.markdown("\n\n".join(tool_outputs))
                else:
                    full_response += content
                    message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)
            retrieval_response = "\n\n".join(tool_outputs)

            st.session_state.messages.append({"role": "assistant", "content": full_response})
            st.session_state.displayed_messages.append(
                {"role": "assistant", "content": full_response, "tool_output": retrieval_response}
            )

    def direct_process_prompt(prompt):
        """Retrieve context for ``prompt`` from the vector DBs and ask the model directly."""
        # Add the system prompt in the beginning of the conversation
        if len(st.session_state.messages) == 0:
            st.session_state.messages.append({"role": "system", "content": system_prompt})

        # Query the vector DB
        rag_response = llama_stack_api.client.tool_runtime.rag_tool.query(
            content=prompt, vector_db_ids=list(selected_vector_dbs)
        )
        # The result content may be a list of content items rather than a string;
        # flatten it so the model sees the retrieved text, not Python reprs.
        prompt_context = _content_to_text(rag_response.content)

        with st.chat_message("assistant"):
            with st.expander(label="Retrieval Output", expanded=False):
                st.write(prompt_context)

            retrieval_message_placeholder = st.empty()
            message_placeholder = st.empty()
            full_response = ""
            retrieval_response = ""

            # Construct the extended prompt
            extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"

            # Run inference directly
            st.session_state.messages.append({"role": "user", "content": extended_prompt})
            response = llama_stack_api.client.inference.chat_completion(
                messages=st.session_state.messages,
                model_id=selected_model,
                sampling_params={
                    "strategy": strategy,
                },
                stream=True,
            )

            # Display assistant response
            for chunk in response:
                response_delta = chunk.event.delta
                # The client may deserialize deltas into its own classes, so also
                # check the "type" discriminator, not only the server-side class.
                is_tool_call = (
                    isinstance(response_delta, ToolCallDelta)
                    or getattr(response_delta, "type", None) == "tool_call"
                )
                if is_tool_call:
                    # tool_call may be a plain string or a structured object.
                    retrieval_response += str(response_delta.tool_call).replace("====", "").strip()
                    retrieval_message_placeholder.info(retrieval_response)
                else:
                    full_response += getattr(response_delta, "text", None) or ""
                    message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)

        # The API history and the displayed history get separate dicts so UI-only
        # keys ("tool_output") never reach the inference request.
        st.session_state.messages.append(
            {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"}
        )
        st.session_state.displayed_messages.append(
            {"role": "assistant", "content": full_response, "tool_output": prompt_context}
        )

    # Chat input
    if prompt := st.chat_input("Ask a question about your documents"):
        # Add user message to chat history
        st.session_state.displayed_messages.append({"role": "user", "content": prompt})

        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)

        # store the prompt to process it after page refresh
        st.session_state.prompt = prompt

        # force page refresh to disable the settings widgets
        st.rerun()

    pending_prompt = st.session_state.get("prompt")
    if pending_prompt is not None:
        # Clear the pending prompt before processing so that a prompt whose
        # processing raises is not silently re-run on every later rerun.
        st.session_state.prompt = None
        if rag_mode == "Agent-based":
            agent_process_prompt(pending_prompt)
        else:  # rag_mode == "Direct"
            direct_process_prompt(pending_prompt)
 | |
| 
 | |
| 
 | |
# Render the page whenever this script is executed (Streamlit re-executes the
# whole script on each rerun, e.g. after st.rerun()).
rag_chat_page()
 |