forked from phoenix-oss/llama-stack-mirror
feat: Updated playground rag to use session id for persistent conversation (#1870)
# What does this PR do? This PR updates the [playground RAG example](llama_stack/distribution/ui/page/playground/rag.py) so that the agent is able to use its builtin conversation history. Here we are using streamlit's `cache_resource` functionality to prevent the agent from re-initializing after every interaction as well as storing its session_id in the `session_state`. This allows the agent in the RAG example to behave more closely to how it works using the python-client directly. [//]: # (If resolving an issue, uncomment and update the line below) Closes #1869 ## Test Plan Without these changes, if you ask it "What is 2 + 2"? followed by the question "What did I just ask?" It will provide an obviously incorrect answer. With these changes, you can ask the same series of questions and it will provide the correct answer. [//]: # (## Documentation) Signed-off-by: Michael Clifford <mcliffor@redhat.com>
This commit is contained in:
parent
7b4eb0967e
commit
c6e93e32f6
1 changed files with 29 additions and 19 deletions
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
|
||||
import streamlit as st
|
||||
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
|
||||
|
||||
|
@ -102,8 +104,8 @@ def rag_chat_page():
|
|||
|
||||
# Add clear chat button to sidebar
|
||||
if st.button("Clear Chat", use_container_width=True):
|
||||
st.session_state.messages = []
|
||||
st.rerun()
|
||||
st.session_state.clear()
|
||||
st.cache_resource.clear()
|
||||
|
||||
# Chat Interface
|
||||
if "messages" not in st.session_state:
|
||||
|
@ -123,23 +125,31 @@ def rag_chat_page():
|
|||
else:
|
||||
strategy = {"type": "greedy"}
|
||||
|
||||
agent = Agent(
|
||||
llama_stack_api.client,
|
||||
model=selected_model,
|
||||
instructions=system_prompt,
|
||||
sampling_params={
|
||||
"strategy": strategy,
|
||||
},
|
||||
tools=[
|
||||
dict(
|
||||
name="builtin::rag/knowledge_search",
|
||||
args={
|
||||
"vector_db_ids": list(selected_vector_dbs),
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
session_id = agent.create_session("rag-session")
|
||||
@st.cache_resource
|
||||
def create_agent():
|
||||
return Agent(
|
||||
llama_stack_api.client,
|
||||
model=selected_model,
|
||||
instructions=system_prompt,
|
||||
sampling_params={
|
||||
"strategy": strategy,
|
||||
},
|
||||
tools=[
|
||||
dict(
|
||||
name="builtin::rag/knowledge_search",
|
||||
args={
|
||||
"vector_db_ids": list(selected_vector_dbs),
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
agent = create_agent()
|
||||
|
||||
if "agent_session_id" not in st.session_state:
|
||||
st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
|
||||
|
||||
session_id = st.session_state["agent_session_id"]
|
||||
|
||||
# Chat input
|
||||
if prompt := st.chat_input("Ask a question about your documents"):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue