feat: Support ReAct Agent on Tools Playground (#2012)

# What does this PR do?
ReAct prompting attempts to use the Thinking, Action, Observation loop
to improve the model's reasoning ability via prompt engineering.

With this PR, it now supports the various features in Streamlit's
playground:
1. Adding the selection box for choosing between Agent Type: normal,
ReAct.
2. Adding the Thinking, Action, Observation loop streamlit logic for
ReAct agent, as seen in many LLM clients.
3. Improving tool calling accuracies via ReAct prompting, e.g. using
web_search.


**Folded**
![react_output_folded
png](https://github.com/user-attachments/assets/bf1bdce7-e6ef-455d-b6b0-c22a64e9d5c1)

**Collapsed**

![react_output_collapsed](https://github.com/user-attachments/assets/cda2fc17-df0b-400d-971c-988de821f2a4)

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

[Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.*]
Run the playground and uses reasoning prompts to see for yourself. Steps
to test the ReAct agent mode:
1. Setup a llama-stack server as
[getting_started](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html)
describes.
2. Setup your Web Search API keys under
`llama_stack/distribution/ui/modules/api.py`.
3. Run the streamlit playground and try ReAct agent, possibly with
`websearch`, with the command: `streamlit run
llama_stack/distribution/ui/app.py`.

## Test Process
Current results are demonstrated with `llama-3.2-3b-instruct`. Results
will vary with different models.

You should be seeing clear distinction with normal agent and ReAct
agent. Example prompts listed below:
1. Aside from the Apple Remote, what other devices can control the
program Apple Remote was originally designed to interact with?
2. What is the elevation range for the area that the eastern sector of
the Colorado orogeny extends into?

## Example Test Results

**Web search on AppleTV**
<img width="1440" alt="normal_output_appletv"
src="https://github.com/user-attachments/assets/bf6b3273-1c94-4976-8b4a-b2d82fe41330"
/>

<img width="1440" alt="react_output_appletv"
src="https://github.com/user-attachments/assets/687f1feb-88f4-4d32-93d5-5013d0d5fe25"
/>

**Web search on Colorado**
<img width="1440" alt="normal_output_colorado"
src="https://github.com/user-attachments/assets/10bd3ad4-f2ad-466d-9ce0-c66fccee40c1"
/>

<img width="1440" alt="react_output_colorado"
src="https://github.com/user-attachments/assets/39cfd82d-2be9-4e2f-9f90-a2c4840185f7"
/>

**Web search tool + MCP Slack server**
<img width="1250" alt="normal_output_search_slack png"
src="https://github.com/user-attachments/assets/72e88125-cdbf-4a90-bcb9-ab412c51d62d"
/>

<img width="1217" alt="react_output_search_slack"
src="https://github.com/user-attachments/assets/8ae04efb-a4fd-49f6-9465-37dbecb6b73e"
/>


![slack_screenshot](https://github.com/user-attachments/assets/bb70e669-6067-462a-bdf6-7aaac6ccbcef)
This commit is contained in:
Andy Xie 2025-04-25 11:01:51 -04:00 committed by GitHub
parent 121c73c2f5
commit f5dae0517c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -4,14 +4,23 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import enum
import json
import uuid
import streamlit as st
from llama_stack_client import Agent
from llama_stack_client.lib.agents.react.agent import ReActAgent
from llama_stack_client.lib.agents.react.tool_parser import ReActOutput
from llama_stack.distribution.ui.modules.api import llama_stack_api
class AgentType(enum.Enum):
REGULAR = "Regular"
REACT = "ReAct"
def tool_chat_page():
st.title("🛠 Tools")
@ -23,6 +32,7 @@ def tool_chat_page():
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
selected_vector_dbs = []
def reset_agent():
st.session_state.clear()
@ -82,12 +92,20 @@ def tool_chat_page():
st.markdown(f"{idx}. `{tool.split(':')[-1]}`")
st.subheader("Agent Configurations")
st.subheader("Agent Type")
agent_type = st.radio(
"Select Agent Type",
[AgentType.REGULAR, AgentType.REACT],
format_func=lambda x: x.value,
on_change=reset_agent,
)
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=1,
step=64,
help="The maximum number of tokens to generate",
on_change=reset_agent,
)
@ -104,13 +122,27 @@ def tool_chat_page():
@st.cache_resource
def create_agent():
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
return ReActAgent(
client=client,
model=model,
tools=toolgroup_selection,
response_format={
"type": "json_schema",
"json_schema": ReActOutput.model_json_schema(),
},
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
else:
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
st.session_state.agent_type = agent_type
agent = create_agent()
@ -139,6 +171,158 @@ def tool_chat_page():
)
def response_generator(turn_response):
if st.session_state.get("agent_type") == AgentType.REACT:
return _handle_react_response(turn_response)
else:
return _handle_regular_response(turn_response)
def _handle_react_response(turn_response):
current_step_content = ""
final_answer = None
tool_results = []
for response in turn_response:
if not hasattr(response.event, "payload"):
yield (
"\n\n🚨 :red[_Llama Stack server Error:_]\n"
"The response received is missing an expected `payload` attribute.\n"
"This could indicate a malformed response or an internal issue within the server.\n\n"
f"Error details: {response}"
)
return
payload = response.event.payload
if payload.event_type == "step_progress" and hasattr(payload.delta, "text"):
current_step_content += payload.delta.text
continue
if payload.event_type == "step_complete":
step_details = payload.step_details
if step_details.step_type == "inference":
yield from _process_inference_step(current_step_content, tool_results, final_answer)
current_step_content = ""
elif step_details.step_type == "tool_execution":
tool_results = _process_tool_execution(step_details, tool_results)
current_step_content = ""
else:
current_step_content = ""
if not final_answer and tool_results:
yield from _format_tool_results_summary(tool_results)
def _process_inference_step(current_step_content, tool_results, final_answer):
try:
react_output_data = json.loads(current_step_content)
thought = react_output_data.get("thought")
action = react_output_data.get("action")
answer = react_output_data.get("answer")
if answer and answer != "null" and answer is not None:
final_answer = answer
if thought:
with st.expander("🤔 Thinking...", expanded=False):
st.markdown(f":grey[__{thought}__]")
if action and isinstance(action, dict):
tool_name = action.get("tool_name")
tool_params = action.get("tool_params")
with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False):
st.json(tool_params)
if answer and answer != "null" and answer is not None:
yield f"\n\n✅ **Final Answer:**\n{answer}"
except json.JSONDecodeError:
yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```"
except Exception as e:
yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```"
return final_answer
def _process_tool_execution(step_details, tool_results):
try:
if hasattr(step_details, "tool_responses") and step_details.tool_responses:
for tool_response in step_details.tool_responses:
tool_name = tool_response.tool_name
content = tool_response.content
tool_results.append((tool_name, content))
with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False):
try:
parsed_content = json.loads(content)
st.json(parsed_content)
except json.JSONDecodeError:
st.code(content, language=None)
else:
with st.expander("⚙️ Observation", expanded=False):
st.markdown(":grey[_Tool execution step completed, but no response data found._]")
except Exception as e:
with st.expander("⚙️ Error in Tool Execution", expanded=False):
st.markdown(f":red[_Error processing tool execution: {str(e)}_]")
return tool_results
def _format_tool_results_summary(tool_results):
yield "\n\n**Here's what I found:**\n"
for tool_name, content in tool_results:
try:
parsed_content = json.loads(content)
if tool_name == "web_search" and "top_k" in parsed_content:
yield from _format_web_search_results(parsed_content)
elif "results" in parsed_content and isinstance(parsed_content["results"], list):
yield from _format_results_list(parsed_content["results"])
elif isinstance(parsed_content, dict) and len(parsed_content) > 0:
yield from _format_dict_results(parsed_content)
elif isinstance(parsed_content, list) and len(parsed_content) > 0:
yield from _format_list_results(parsed_content)
except json.JSONDecodeError:
yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n"
except (TypeError, AttributeError, KeyError, IndexError) as e:
print(f"Error processing {tool_name} result: {type(e).__name__}: {e}")
def _format_web_search_results(parsed_content):
for i, result in enumerate(parsed_content["top_k"], 1):
if i <= 3:
title = result.get("title", "Untitled")
url = result.get("url", "")
content_text = result.get("content", "").strip()
yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n"
def _format_results_list(results):
for i, result in enumerate(results, 1):
if i <= 3:
if isinstance(result, dict):
name = result.get("name", result.get("title", "Result " + str(i)))
description = result.get("description", result.get("content", result.get("summary", "")))
yield f"\n- **{name}**\n {description}\n"
else:
yield f"\n- {result}\n"
def _format_dict_results(parsed_content):
yield "\n```\n"
for key, value in list(parsed_content.items())[:5]:
if isinstance(value, str) and len(value) < 100:
yield f"{key}: {value}\n"
else:
yield f"{key}: [Complex data]\n"
yield "```\n"
def _format_list_results(parsed_content):
yield "\n"
for _, item in enumerate(parsed_content[:3], 1):
if isinstance(item, str):
yield f"- {item}\n"
elif isinstance(item, dict) and "text" in item:
yield f"- {item['text']}\n"
elif isinstance(item, dict) and len(item) > 0:
first_value = next(iter(item.values()))
if isinstance(first_value, str) and len(first_value) < 100:
yield f"- {first_value}\n"
def _handle_regular_response(turn_response):
for response in turn_response:
if hasattr(response.event, "payload"):
print(response.event.payload)
@ -156,9 +340,9 @@ def tool_chat_page():
yield f"Error occurred in the Llama Stack Cluster: {response}"
with st.chat_message("assistant"):
response = st.write_stream(response_generator(turn_response))
response_content = st.write_stream(response_generator(turn_response))
st.session_state.messages.append({"role": "assistant", "content": response})
st.session_state.messages.append({"role": "assistant", "content": response_content})
tool_chat_page()