From b54c0c61aa6d0089c1b7223d2d91df899cca286c Mon Sep 17 00:00:00 2001 From: Michael Clifford Date: Tue, 15 Apr 2025 14:39:40 -0400 Subject: [PATCH] updated tools playground to allow vdb selection Signed-off-by: Michael Clifford --- .../distribution/ui/page/playground/tools.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/llama_stack/distribution/ui/page/playground/tools.py b/llama_stack/distribution/ui/page/playground/tools.py index bc2e8975f..fac6ef52a 100644 --- a/llama_stack/distribution/ui/page/playground/tools.py +++ b/llama_stack/distribution/ui/page/playground/tools.py @@ -37,6 +37,17 @@ def tool_chat_page(): label="Available ToolGroups", options=builtin_tools_list, selection_mode="multi", on_change=reset_agent ) + if "builtin::rag" in toolgroup_selection: + vector_dbs = llama_stack_api.client.vector_dbs.list() or [] + if not vector_dbs: + st.info("No vector databases available for selection.") + vector_dbs = [vector_db.identifier for vector_db in vector_dbs] + selected_vector_dbs = st.multiselect( + label="Select Document Collections to use in RAG queries", + options=vector_dbs, + on_change=reset_agent, + ) + st.subheader("MCP Servers") mcp_selection = st.pills( label="Available MCP Servers", options=mcp_tools_list, selection_mode="multi", on_change=reset_agent @@ -67,6 +78,16 @@ def tool_chat_page(): on_change=reset_agent, ) + for i, tool_name in enumerate(toolgroup_selection): + if tool_name == "builtin::rag": + tool_dict = dict( + name="builtin::rag", + args={ + "vector_db_ids": list(selected_vector_dbs), + }, + ) + toolgroup_selection[i] = tool_dict + @st.cache_resource def create_agent(): return Agent(