# What does this PR do?

This PR fixes two issues with the RAG page of the Playground UI:

1. When the user modifies a configurable setting via a widget (e.g., system prompt, temperature), the agent is not recreated. The change therefore has no effect, and the user gets no indication of that.
2. Once the first issue is fixed, it becomes possible to recreate the agent mid-conversation or even mid-generation. To mitigate this, widgets related to agent configuration are now disabled while a conversation is in progress (i.e., while the chat is non-empty). They are automatically re-enabled when the user resets the chat history.

A condensed sketch of the resulting pattern appears right after the test plan.

## Test Plan

- Launch the Playground and go to the RAG page.
- Select the vector DB ID.
- Send a message to the agent via the chat; the widgets controlling the agent parameters become disabled at this point.
- Send a second message asking the model about the content of the first message; the reply indicates that both messages were sent over the same session, i.e., the agent was not recreated.
- Click the 'Clear Chat' button; all widgets are re-enabled and a new agent is created (which can be validated by sending another message).
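The core of the fix, in miniature (condensed from the page code below; the helper names match the actual ones):

```python
import streamlit as st

def reset_agent_and_chat():
    # Dropping session state and the cached resources forces the cached
    # agent factory to rebuild the agent on the next rerun.
    st.session_state.clear()
    st.cache_resource.clear()

def should_disable_input():
    # A conversation is "in progress" whenever the chat history is non-empty.
    return "messages" in st.session_state and len(st.session_state.messages) > 0

temperature = st.slider(
    "Temperature",
    min_value=0.0,
    max_value=1.0,
    value=0.0,
    step=0.1,
    on_change=reset_agent_and_chat,  # recreate the agent when the setting changes
    disabled=should_disable_input(),  # lock the widget while chatting
)
```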
221 lines · 7.6 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid

import streamlit as st
from llama_stack_client import Agent, AgentEventLogger, RAGDocument

from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import data_url_from_file


def rag_chat_page():
    st.title("🦙 RAG")
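
    # Resetting drops both the Streamlit session state (chat history, session
    # id) and the cache_resource cache, so the next rerun rebuilds the agent
    # from the current widget values.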
    def reset_agent_and_chat():
        st.session_state.clear()
        st.cache_resource.clear()
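
    # The configuration widgets are disabled while a conversation is in
    # progress, i.e. while the chat history is non-empty.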
    def should_disable_input():
        return "messages" in st.session_state and len(st.session_state.messages) > 0

    with st.sidebar:
        # File/Directory Upload Section
        st.subheader("Upload Documents")
        uploaded_files = st.file_uploader(
            "Upload file(s) or directory",
            accept_multiple_files=True,
            type=["txt", "pdf", "doc", "docx"],  # Add more file types as needed
        )
        # Process uploaded files
        if uploaded_files:
            st.success(f"Successfully uploaded {len(uploaded_files)} files")
            # Add vector database name input field
            vector_db_name = st.text_input(
                "Vector Database Name",
                value="rag_vector_db",
                help="Enter a unique identifier for this vector database",
            )
            if st.button("Create Vector Database"):
                documents = [
                    RAGDocument(
                        document_id=uploaded_file.name,
                        content=data_url_from_file(uploaded_file),
                    )
                    for uploaded_file in uploaded_files
                ]

                providers = llama_stack_api.client.providers.list()
                vector_io_provider = None

                for provider in providers:
                    if provider.api == "vector_io":
                        vector_io_provider = provider.provider_id
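
                # Register the new vector DB. The embedding dimension must
                # match the chosen embedding model (all-MiniLM-L6-v2 produces
                # 384-dimensional embeddings).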
                llama_stack_api.client.vector_dbs.register(
                    vector_db_id=vector_db_name,  # Use the user-provided name
                    embedding_dimension=384,
                    embedding_model="all-MiniLM-L6-v2",
                    provider_id=vector_io_provider,
                )

                # insert documents using the custom vector db name
                llama_stack_api.client.tool_runtime.rag_tool.insert(
                    vector_db_id=vector_db_name,  # Use the user-provided name
                    documents=documents,
                    chunk_size_in_tokens=512,
                )
                st.success("Vector database created successfully!")

        st.subheader("Configure Agent")
        # select vector databases
        vector_dbs = llama_stack_api.client.vector_dbs.list()
        vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
        selected_vector_dbs = st.multiselect(
            label="Select Vector Databases",
            options=vector_dbs,
            on_change=reset_agent_and_chat,
            disabled=should_disable_input(),
        )

        available_models = llama_stack_api.client.models.list()
        available_models = [model.identifier for model in available_models if model.model_type == "llm"]
        selected_model = st.selectbox(
            label="Choose a model",
            options=available_models,
            index=0,
            on_change=reset_agent_and_chat,
            disabled=should_disable_input(),
        )
        system_prompt = st.text_area(
            "System Prompt",
            value="You are a helpful assistant. ",
            help="Initial instructions given to the AI to set its behavior and context",
            on_change=reset_agent_and_chat,
            disabled=should_disable_input(),
        )
        temperature = st.slider(
            "Temperature",
            min_value=0.0,
            max_value=1.0,
            value=0.0,
            step=0.1,
            help="Controls the randomness of the response. Higher values make the output more creative and unexpected; lower values make it more conservative and predictable.",
            on_change=reset_agent_and_chat,
            disabled=should_disable_input(),
        )

        top_p = st.slider(
            "Top P",
            min_value=0.0,
            max_value=1.0,
            value=0.95,
            step=0.1,
            on_change=reset_agent_and_chat,
            disabled=should_disable_input(),
        )

        # Add clear chat button to sidebar
        if st.button("Clear Chat", use_container_width=True):
            reset_agent_and_chat()
            st.rerun()

    # Chat Interface
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if temperature > 0.0:
        strategy = {
            "type": "top_p",
            "temperature": temperature,
            "top_p": top_p,
        }
    else:
        strategy = {"type": "greedy"}
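
    # Note: with the default temperature of 0.0, a greedy strategy is selected
    # and the Top P setting has no effect.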

    @st.cache_resource
    def create_agent():
        return Agent(
            llama_stack_api.client,
            model=selected_model,
            instructions=system_prompt,
            sampling_params={
                "strategy": strategy,
            },
            tools=[
                dict(
                    name="builtin::rag/knowledge_search",
                    args={
                        "vector_db_ids": list(selected_vector_dbs),
                    },
                )
            ],
        )
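
    # st.cache_resource memoizes create_agent() across reruns, so the same
    # Agent instance is reused until reset_agent_and_chat() clears the cache
    # (which is what makes widget changes take effect again).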
    agent = create_agent()

    if "agent_session_id" not in st.session_state:
        st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")

    session_id = st.session_state["agent_session_id"]

    def process_prompt(prompt):
        # Send the prompt to the agent
        response = agent.create_turn(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            session_id=session_id,
        )

        # Display assistant response
        with st.chat_message("assistant"):
            retrieval_message_placeholder = st.empty()
            message_placeholder = st.empty()
            full_response = ""
            retrieval_response = ""
            for log in AgentEventLogger().log(response):
                log.print()
                if log.role == "tool_execution":
                    retrieval_response += log.content.replace("====", "").strip()
                    retrieval_message_placeholder.info(retrieval_response)
                else:
                    full_response += log.content
                    message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)

        st.session_state.messages.append({"role": "assistant", "content": full_response})

    # Chat input
    if prompt := st.chat_input("Ask a question about your documents"):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)

        # store the prompt to process it after page refresh
        st.session_state.prompt = prompt

        # force page refresh to disable the settings widgets
        st.rerun()
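
    # After the rerun, the chat history is non-empty, so the settings widgets
    # render disabled; the stored prompt is then processed here.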
if "prompt" in st.session_state and st.session_state.prompt is not None:
|
|
process_prompt(st.session_state.prompt)
|
|
st.session_state.prompt = None
|
|
|
|
|
|
rag_chat_page()
|