diff --git a/llama_stack/ui/app/chat-playground/page.tsx b/llama_stack/ui/app/chat-playground/page.tsx index c31248b78..c900c57ab 100644 --- a/llama_stack/ui/app/chat-playground/page.tsx +++ b/llama_stack/ui/app/chat-playground/page.tsx @@ -15,6 +15,8 @@ import { type Message } from "@/components/chat-playground/chat-message"; import { useAuthClient } from "@/hooks/use-auth-client"; import type { CompletionCreateParams } from "llama-stack-client/resources/chat/completions"; import type { Model } from "llama-stack-client/resources/models"; +import type { VectorDBListResponse } from "llama-stack-client/resources/vector-dbs"; +import { VectorDbManager } from "@/components/vector-db/vector-db-manager"; export default function ChatPlaygroundPage() { const [messages, setMessages] = useState([]); @@ -25,6 +27,10 @@ export default function ChatPlaygroundPage() { const [selectedModel, setSelectedModel] = useState(""); const [modelsLoading, setModelsLoading] = useState(true); const [modelsError, setModelsError] = useState(null); + const [vectorDbs, setVectorDbs] = useState([]); + const [selectedVectorDb, setSelectedVectorDb] = useState(""); + const [vectorDbsLoading, setVectorDbsLoading] = useState(true); + const [vectorDbsError, setVectorDbsError] = useState(null); const client = useAuthClient(); const isModelsLoading = modelsLoading ?? true; @@ -49,7 +55,22 @@ export default function ChatPlaygroundPage() { } }; + const fetchVectorDbs = async () => { + try { + setVectorDbsLoading(true); + setVectorDbsError(null); + const vectorDbList = await client.vectorDBs.list(); + setVectorDbs(vectorDbList); + } catch (err) { + console.error("Error fetching vector DBs:", err); + setVectorDbsError("Failed to fetch available vector databases"); + } finally { + setVectorDbsLoading(false); + } + }; + fetchModels(); + fetchVectorDbs(); }, [client]); const extractTextContent = (content: unknown): string => { @@ -96,6 +117,35 @@ const handleSubmitWithContent = async (content: string) => { setError(null); try { + let enhancedContent = content; + + // If a vector DB is selected, query for relevant context + if (selectedVectorDb && selectedVectorDb !== "none") { + try { + const vectorResponse = await client.vectorIo.query({ + query: content, + vector_db_id: selectedVectorDb, + }); + + if (vectorResponse.chunks && vectorResponse.chunks.length > 0) { + const context = vectorResponse.chunks + .map(chunk => { + // Extract text content from the chunk + const chunkContent = typeof chunk.content === 'string' + ? chunk.content + : extractTextContent(chunk.content); + return chunkContent; + }) + .join('\n\n'); + + enhancedContent = `Please answer the following query using the context below.\n\nCONTEXT:\n${context}\n\nQUERY:\n${content}`; + } + } catch (vectorErr) { + console.error("Error querying vector DB:", vectorErr); + // Continue with original content if vector query fails + } + } + const messageParams: CompletionCreateParams["messages"] = [ ...messages.map(msg => { const msgContent = typeof msg.content === 'string' ? msg.content : extractTextContent(msg.content); @@ -107,7 +157,7 @@ const handleSubmitWithContent = async (content: string) => { return { role: "system" as const, content: msgContent }; } }), - { role: "user" as const, content } + { role: "user" as const, content: enhancedContent } ]; const response = await client.chat.completions.create({ @@ -172,6 +222,20 @@ const handleSubmitWithContent = async (content: string) => { setError(null); }; + const refreshVectorDbs = async () => { + try { + setVectorDbsLoading(true); + setVectorDbsError(null); + const vectorDbList = await client.vectorDBs.list(); + setVectorDbs(vectorDbList); + } catch (err) { + console.error("Error refreshing vector DBs:", err); + setVectorDbsError("Failed to refresh vector databases"); + } finally { + setVectorDbsLoading(false); + } + }; + return (
@@ -189,6 +253,23 @@ const handleSubmitWithContent = async (content: string) => { ))} + + @@ -201,6 +282,12 @@ const handleSubmitWithContent = async (content: string) => {
)} + {vectorDbsError && ( +
+

{vectorDbsError}

+
+ )} + {error && (

{error}