diff --git a/llama_stack/distribution/ui/page/playground/tools.py b/llama_stack/distribution/ui/page/playground/tools.py index bc2e8975f..fac6ef52a 100644 --- a/llama_stack/distribution/ui/page/playground/tools.py +++ b/llama_stack/distribution/ui/page/playground/tools.py @@ -37,6 +37,17 @@ def tool_chat_page(): label="Available ToolGroups", options=builtin_tools_list, selection_mode="multi", on_change=reset_agent ) + if "builtin::rag" in toolgroup_selection: + vector_dbs = llama_stack_api.client.vector_dbs.list() or [] + if not vector_dbs: + st.info("No vector databases available for selection.") + vector_dbs = [vector_db.identifier for vector_db in vector_dbs] + selected_vector_dbs = st.multiselect( + label="Select Document Collections to use in RAG queries", + options=vector_dbs, + on_change=reset_agent, + ) + st.subheader("MCP Servers") mcp_selection = st.pills( label="Available MCP Servers", options=mcp_tools_list, selection_mode="multi", on_change=reset_agent @@ -67,6 +78,16 @@ def tool_chat_page(): on_change=reset_agent, ) + for i, tool_name in enumerate(toolgroup_selection): + if tool_name == "builtin::rag": + tool_dict = dict( + name="builtin::rag", + args={ + "vector_db_ids": list(selected_vector_dbs), + }, + ) + toolgroup_selection[i] = tool_dict + @st.cache_resource def create_agent(): return Agent(