mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
Maded changes on streamlit with ReAct agent mode
Signed-off-by: Andy Xie <anxie@redhat.com> fixing pre-commits Signed-off-by: Andy Xie <xiehandi@gmail.com> Ruff pre-commit fixes
This commit is contained in:
parent
fa5dfee07b
commit
92142fe0cd
1 changed files with 194 additions and 10 deletions
|
@ -4,14 +4,23 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import enum
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import streamlit as st
|
||||
from llama_stack_client import Agent
|
||||
from llama_stack_client.lib.agents.react.agent import ReActAgent
|
||||
from llama_stack_client.lib.agents.react.tool_parser import ReActOutput
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
class AgentType(enum.Enum):
|
||||
REGULAR = "Regular"
|
||||
REACT = "ReAct"
|
||||
|
||||
|
||||
def tool_chat_page():
|
||||
st.title("🛠 Tools")
|
||||
|
||||
|
@ -23,6 +32,7 @@ def tool_chat_page():
|
|||
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
|
||||
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
|
||||
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
|
||||
selected_vector_dbs = []
|
||||
|
||||
def reset_agent():
|
||||
st.session_state.clear()
|
||||
|
@ -79,12 +89,20 @@ def tool_chat_page():
|
|||
st.json(active_tool_list)
|
||||
|
||||
st.subheader("Agent Configurations")
|
||||
st.subheader("Agent Type")
|
||||
agent_type = st.radio(
|
||||
"Select Agent Type",
|
||||
[AgentType.REGULAR, AgentType.REACT],
|
||||
format_func=lambda x: x.value,
|
||||
on_change=reset_agent,
|
||||
)
|
||||
|
||||
max_tokens = st.slider(
|
||||
"Max Tokens",
|
||||
min_value=0,
|
||||
max_value=4096,
|
||||
value=512,
|
||||
step=1,
|
||||
step=64,
|
||||
help="The maximum number of tokens to generate",
|
||||
on_change=reset_agent,
|
||||
)
|
||||
|
@ -101,13 +119,27 @@ def tool_chat_page():
|
|||
|
||||
@st.cache_resource
|
||||
def create_agent():
|
||||
return Agent(
|
||||
client,
|
||||
model=model,
|
||||
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
|
||||
tools=toolgroup_selection,
|
||||
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
||||
)
|
||||
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
|
||||
return ReActAgent(
|
||||
client=client,
|
||||
model=model,
|
||||
tools=toolgroup_selection,
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": ReActOutput.model_json_schema(),
|
||||
},
|
||||
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
||||
)
|
||||
else:
|
||||
return Agent(
|
||||
client,
|
||||
model=model,
|
||||
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
|
||||
tools=toolgroup_selection,
|
||||
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
||||
)
|
||||
|
||||
st.session_state.agent_type = agent_type
|
||||
|
||||
agent = create_agent()
|
||||
|
||||
|
@ -136,6 +168,158 @@ def tool_chat_page():
|
|||
)
|
||||
|
||||
def response_generator(turn_response):
|
||||
if st.session_state.get("agent_type") == AgentType.REACT:
|
||||
return _handle_react_response(turn_response)
|
||||
else:
|
||||
return _handle_regular_response(turn_response)
|
||||
|
||||
def _handle_react_response(turn_response):
|
||||
current_step_content = ""
|
||||
final_answer = None
|
||||
tool_results = []
|
||||
|
||||
for response in turn_response:
|
||||
if not hasattr(response.event, "payload"):
|
||||
yield (
|
||||
"\n\n🚨 :red[_Llama Stack server Error:_]\n"
|
||||
"The response received is missing an expected `payload` attribute.\n"
|
||||
"This could indicate a malformed response or an internal issue within the server.\n\n"
|
||||
f"Error details: {response}"
|
||||
)
|
||||
return
|
||||
|
||||
payload = response.event.payload
|
||||
|
||||
if payload.event_type == "step_progress" and hasattr(payload.delta, "text"):
|
||||
current_step_content += payload.delta.text
|
||||
continue
|
||||
|
||||
if payload.event_type == "step_complete":
|
||||
step_details = payload.step_details
|
||||
|
||||
if step_details.step_type == "inference":
|
||||
yield from _process_inference_step(current_step_content, tool_results, final_answer)
|
||||
current_step_content = ""
|
||||
elif step_details.step_type == "tool_execution":
|
||||
tool_results = _process_tool_execution(step_details, tool_results)
|
||||
current_step_content = ""
|
||||
else:
|
||||
current_step_content = ""
|
||||
|
||||
if not final_answer and tool_results:
|
||||
yield from _format_tool_results_summary(tool_results)
|
||||
|
||||
def _process_inference_step(current_step_content, tool_results, final_answer):
|
||||
try:
|
||||
react_output_data = json.loads(current_step_content)
|
||||
thought = react_output_data.get("thought")
|
||||
action = react_output_data.get("action")
|
||||
answer = react_output_data.get("answer")
|
||||
|
||||
if answer and answer != "null" and answer is not None:
|
||||
final_answer = answer
|
||||
|
||||
if thought:
|
||||
with st.expander("🤔 Thinking...", expanded=False):
|
||||
st.markdown(f":grey[__{thought}__]")
|
||||
|
||||
if action and isinstance(action, dict):
|
||||
tool_name = action.get("tool_name")
|
||||
tool_params = action.get("tool_params")
|
||||
with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False):
|
||||
st.json(tool_params)
|
||||
|
||||
if answer and answer != "null" and answer is not None:
|
||||
yield f"\n\n✅ **Final Answer:**\n{answer}"
|
||||
|
||||
except json.JSONDecodeError:
|
||||
yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```"
|
||||
except Exception as e:
|
||||
yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```"
|
||||
|
||||
return final_answer
|
||||
|
||||
def _process_tool_execution(step_details, tool_results):
|
||||
try:
|
||||
if hasattr(step_details, "tool_responses") and step_details.tool_responses:
|
||||
for tool_response in step_details.tool_responses:
|
||||
tool_name = tool_response.tool_name
|
||||
content = tool_response.content
|
||||
tool_results.append((tool_name, content))
|
||||
with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False):
|
||||
try:
|
||||
parsed_content = json.loads(content)
|
||||
st.json(parsed_content)
|
||||
except json.JSONDecodeError:
|
||||
st.code(content, language=None)
|
||||
else:
|
||||
with st.expander("⚙️ Observation", expanded=False):
|
||||
st.markdown(":grey[_Tool execution step completed, but no response data found._]")
|
||||
except Exception as e:
|
||||
with st.expander("⚙️ Error in Tool Execution", expanded=False):
|
||||
st.markdown(f":red[_Error processing tool execution: {str(e)}_]")
|
||||
|
||||
return tool_results
|
||||
|
||||
def _format_tool_results_summary(tool_results):
|
||||
yield "\n\n**Here's what I found:**\n"
|
||||
for tool_name, content in tool_results:
|
||||
try:
|
||||
parsed_content = json.loads(content)
|
||||
|
||||
if tool_name == "web_search" and "top_k" in parsed_content:
|
||||
yield from _format_web_search_results(parsed_content)
|
||||
elif "results" in parsed_content and isinstance(parsed_content["results"], list):
|
||||
yield from _format_results_list(parsed_content["results"])
|
||||
elif isinstance(parsed_content, dict) and len(parsed_content) > 0:
|
||||
yield from _format_dict_results(parsed_content)
|
||||
elif isinstance(parsed_content, list) and len(parsed_content) > 0:
|
||||
yield from _format_list_results(parsed_content)
|
||||
except json.JSONDecodeError:
|
||||
yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n"
|
||||
except (TypeError, AttributeError, KeyError, IndexError) as e:
|
||||
print(f"Error processing {tool_name} result: {type(e).__name__}: {e}")
|
||||
|
||||
def _format_web_search_results(parsed_content):
|
||||
for i, result in enumerate(parsed_content["top_k"], 1):
|
||||
if i <= 3:
|
||||
title = result.get("title", "Untitled")
|
||||
url = result.get("url", "")
|
||||
content_text = result.get("content", "").strip()
|
||||
yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n"
|
||||
|
||||
def _format_results_list(results):
|
||||
for i, result in enumerate(results, 1):
|
||||
if i <= 3:
|
||||
if isinstance(result, dict):
|
||||
name = result.get("name", result.get("title", "Result " + str(i)))
|
||||
description = result.get("description", result.get("content", result.get("summary", "")))
|
||||
yield f"\n- **{name}**\n {description}\n"
|
||||
else:
|
||||
yield f"\n- {result}\n"
|
||||
|
||||
def _format_dict_results(parsed_content):
|
||||
yield "\n```\n"
|
||||
for key, value in list(parsed_content.items())[:5]:
|
||||
if isinstance(value, str) and len(value) < 100:
|
||||
yield f"{key}: {value}\n"
|
||||
else:
|
||||
yield f"{key}: [Complex data]\n"
|
||||
yield "```\n"
|
||||
|
||||
def _format_list_results(parsed_content):
|
||||
yield "\n"
|
||||
for _, item in enumerate(parsed_content[:3], 1):
|
||||
if isinstance(item, str):
|
||||
yield f"- {item}\n"
|
||||
elif isinstance(item, dict) and "text" in item:
|
||||
yield f"- {item['text']}\n"
|
||||
elif isinstance(item, dict) and len(item) > 0:
|
||||
first_value = next(iter(item.values()))
|
||||
if isinstance(first_value, str) and len(first_value) < 100:
|
||||
yield f"- {first_value}\n"
|
||||
|
||||
def _handle_regular_response(turn_response):
|
||||
for response in turn_response:
|
||||
if hasattr(response.event, "payload"):
|
||||
print(response.event.payload)
|
||||
|
@ -153,9 +337,9 @@ def tool_chat_page():
|
|||
yield f"Error occurred in the Llama Stack Cluster: {response}"
|
||||
|
||||
with st.chat_message("assistant"):
|
||||
response = st.write_stream(response_generator(turn_response))
|
||||
response_content = st.write_stream(response_generator(turn_response))
|
||||
|
||||
st.session_state.messages.append({"role": "assistant", "content": response})
|
||||
st.session_state.messages.append({"role": "assistant", "content": response_content})
|
||||
|
||||
|
||||
tool_chat_page()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue