diff --git a/llama_stack/distribution/ui/page/playground/tools.py b/llama_stack/distribution/ui/page/playground/tools.py index 5e19c1e4f..6c6a9fcfd 100644 --- a/llama_stack/distribution/ui/page/playground/tools.py +++ b/llama_stack/distribution/ui/page/playground/tools.py @@ -4,14 +4,23 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import enum +import json import uuid import streamlit as st from llama_stack_client import Agent +from llama_stack_client.lib.agents.react.agent import ReActAgent +from llama_stack_client.lib.agents.react.tool_parser import ReActOutput from llama_stack.distribution.ui.modules.api import llama_stack_api +class AgentType(enum.Enum): + REGULAR = "Regular" + REACT = "ReAct" + + def tool_chat_page(): st.title("šŸ›  Tools") @@ -23,6 +32,7 @@ def tool_chat_page(): tool_groups_list = [tool_group.identifier for tool_group in tool_groups] mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")] builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")] + selected_vector_dbs = [] def reset_agent(): st.session_state.clear() @@ -82,12 +92,20 @@ def tool_chat_page(): st.markdown(f"{idx}. `{tool.split(':')[-1]}`") st.subheader("Agent Configurations") + st.subheader("Agent Type") + agent_type = st.radio( + "Select Agent Type", + [AgentType.REGULAR, AgentType.REACT], + format_func=lambda x: x.value, + on_change=reset_agent, + ) + max_tokens = st.slider( "Max Tokens", min_value=0, max_value=4096, value=512, - step=1, + step=64, help="The maximum number of tokens to generate", on_change=reset_agent, ) @@ -104,13 +122,27 @@ def tool_chat_page(): @st.cache_resource def create_agent(): - return Agent( - client, - model=model, - instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.", - tools=toolgroup_selection, - sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens}, - ) + if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT: + return ReActAgent( + client=client, + model=model, + tools=toolgroup_selection, + response_format={ + "type": "json_schema", + "json_schema": ReActOutput.model_json_schema(), + }, + sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens}, + ) + else: + return Agent( + client, + model=model, + instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.", + tools=toolgroup_selection, + sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens}, + ) + + st.session_state.agent_type = agent_type agent = create_agent() @@ -139,6 +171,158 @@ def tool_chat_page(): ) def response_generator(turn_response): + if st.session_state.get("agent_type") == AgentType.REACT: + return _handle_react_response(turn_response) + else: + return _handle_regular_response(turn_response) + + def _handle_react_response(turn_response): + current_step_content = "" + final_answer = None + tool_results = [] + + for response in turn_response: + if not hasattr(response.event, "payload"): + yield ( + "\n\n🚨 :red[_Llama Stack server Error:_]\n" + "The response received is missing an expected `payload` attribute.\n" + "This could indicate a malformed response or an internal issue within the server.\n\n" + f"Error details: {response}" + ) + return + + payload = response.event.payload + + if payload.event_type == "step_progress" and hasattr(payload.delta, "text"): + current_step_content += payload.delta.text + continue + + if payload.event_type == "step_complete": + step_details = payload.step_details + + if step_details.step_type == "inference": + yield from _process_inference_step(current_step_content, tool_results, final_answer) + current_step_content = "" + elif step_details.step_type == "tool_execution": + tool_results = _process_tool_execution(step_details, tool_results) + current_step_content = "" + else: + current_step_content = "" + + if not final_answer and tool_results: + yield from _format_tool_results_summary(tool_results) + + def _process_inference_step(current_step_content, tool_results, final_answer): + try: + react_output_data = json.loads(current_step_content) + thought = react_output_data.get("thought") + action = react_output_data.get("action") + answer = react_output_data.get("answer") + + if answer and answer != "null" and answer is not None: + final_answer = answer + + if thought: + with st.expander("šŸ¤” Thinking...", expanded=False): + st.markdown(f":grey[__{thought}__]") + + if action and isinstance(action, dict): + tool_name = action.get("tool_name") + tool_params = action.get("tool_params") + with st.expander(f'šŸ›  Action: Using tool "{tool_name}"', expanded=False): + st.json(tool_params) + + if answer and answer != "null" and answer is not None: + yield f"\n\nāœ… **Final Answer:**\n{answer}" + + except json.JSONDecodeError: + yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```" + except Exception as e: + yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```" + + return final_answer + + def _process_tool_execution(step_details, tool_results): + try: + if hasattr(step_details, "tool_responses") and step_details.tool_responses: + for tool_response in step_details.tool_responses: + tool_name = tool_response.tool_name + content = tool_response.content + tool_results.append((tool_name, content)) + with st.expander(f'āš™ļø Observation (Result from "{tool_name}")', expanded=False): + try: + parsed_content = json.loads(content) + st.json(parsed_content) + except json.JSONDecodeError: + st.code(content, language=None) + else: + with st.expander("āš™ļø Observation", expanded=False): + st.markdown(":grey[_Tool execution step completed, but no response data found._]") + except Exception as e: + with st.expander("āš™ļø Error in Tool Execution", expanded=False): + st.markdown(f":red[_Error processing tool execution: {str(e)}_]") + + return tool_results + + def _format_tool_results_summary(tool_results): + yield "\n\n**Here's what I found:**\n" + for tool_name, content in tool_results: + try: + parsed_content = json.loads(content) + + if tool_name == "web_search" and "top_k" in parsed_content: + yield from _format_web_search_results(parsed_content) + elif "results" in parsed_content and isinstance(parsed_content["results"], list): + yield from _format_results_list(parsed_content["results"]) + elif isinstance(parsed_content, dict) and len(parsed_content) > 0: + yield from _format_dict_results(parsed_content) + elif isinstance(parsed_content, list) and len(parsed_content) > 0: + yield from _format_list_results(parsed_content) + except json.JSONDecodeError: + yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n" + except (TypeError, AttributeError, KeyError, IndexError) as e: + print(f"Error processing {tool_name} result: {type(e).__name__}: {e}") + + def _format_web_search_results(parsed_content): + for i, result in enumerate(parsed_content["top_k"], 1): + if i <= 3: + title = result.get("title", "Untitled") + url = result.get("url", "") + content_text = result.get("content", "").strip() + yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n" + + def _format_results_list(results): + for i, result in enumerate(results, 1): + if i <= 3: + if isinstance(result, dict): + name = result.get("name", result.get("title", "Result " + str(i))) + description = result.get("description", result.get("content", result.get("summary", ""))) + yield f"\n- **{name}**\n {description}\n" + else: + yield f"\n- {result}\n" + + def _format_dict_results(parsed_content): + yield "\n```\n" + for key, value in list(parsed_content.items())[:5]: + if isinstance(value, str) and len(value) < 100: + yield f"{key}: {value}\n" + else: + yield f"{key}: [Complex data]\n" + yield "```\n" + + def _format_list_results(parsed_content): + yield "\n" + for _, item in enumerate(parsed_content[:3], 1): + if isinstance(item, str): + yield f"- {item}\n" + elif isinstance(item, dict) and "text" in item: + yield f"- {item['text']}\n" + elif isinstance(item, dict) and len(item) > 0: + first_value = next(iter(item.values())) + if isinstance(first_value, str) and len(first_value) < 100: + yield f"- {first_value}\n" + + def _handle_regular_response(turn_response): for response in turn_response: if hasattr(response.event, "payload"): print(response.event.payload) @@ -156,9 +340,9 @@ def tool_chat_page(): yield f"Error occurred in the Llama Stack Cluster: {response}" with st.chat_message("assistant"): - response = st.write_stream(response_generator(turn_response)) + response_content = st.write_stream(response_generator(turn_response)) - st.session_state.messages.append({"role": "assistant", "content": response}) + st.session_state.messages.append({"role": "assistant", "content": response_content}) tool_chat_page()