From 8374d4cefd3e59c4dad68ec2842622b5f4154fd5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 12 Jul 2025 16:23:42 -0400 Subject: [PATCH 01/12] chore(github-deps): bump medyagh/setup-minikube from 0.0.19 to 0.0.20 (#2738) --- .github/workflows/integration-auth-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index 7822e4216..cf10e005c 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -35,7 +35,7 @@ jobs: - name: Install minikube if: ${{ matrix.auth-provider == 'kubernetes' }} - uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19 + uses: medyagh/setup-minikube@e3c7f79eb1e997eabccc536a6cf318a2b0fe19d9 # v0.0.20 - name: Start minikube if: ${{ matrix.auth-provider == 'oauth2_token' }} From 68e7978c8890fd0aec901e0871fb92061a0d8fa5 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Sat, 12 Jul 2025 19:53:54 -0400 Subject: [PATCH 02/12] chore: block network access from unit tests (#2732) # What does this PR do? this blocks network access for all `tests/unit/` tests. `tests/integration/` are untouched. it also introduces an `allow_network` marker to explicitly allow network access. ## Test Plan `./scripts/unit-tests.sh` --- pyproject.toml | 4 ++++ tests/unit/conftest.py | 11 +++++++++++ tests/unit/providers/inference/test_remote_vllm.py | 1 + tests/unit/rag/test_vector_store.py | 2 ++ uv.lock | 14 ++++++++++++++ 5 files changed, 32 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 9977d7372..2974ff996 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ dev = [ "pytest-cov", "pytest-html", "pytest-json-report", + "pytest-socket", # For blocking network access in unit tests "nbval", # For notebook testing "black", "ruff", @@ -344,3 +345,6 @@ classmethod-decorators = ["classmethod", "pydantic.field_validator"] [tool.pytest.ini_options] asyncio_mode = "auto" +markers = [ + "allow_network: Allow network access for specific unit tests", +] diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index aedac0386..b5eb1217d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -4,6 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import pytest_socket + # We need to import the fixtures here so that pytest can find them # but ruff doesn't think they are used and removes the import. 
"noqa: F401" prevents them from being removed from .fixtures import cached_disk_dist_registry, disk_dist_registry, sqlite_kvstore # noqa: F401 + + +def pytest_runtest_setup(item): + """Setup for each test - check if network access should be allowed.""" + if "allow_network" in item.keywords: + pytest_socket.enable_socket() + else: + # Allowing Unix sockets is necessary for some tests that use local servers and mocks + pytest_socket.disable_socket(allow_unix_socket=True) diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index ca44cc95d..5c2ad03ab 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -393,6 +393,7 @@ async def test_process_vllm_chat_completion_stream_response_no_choices(): assert chunks[0].event.event_type.value == "start" +@pytest.mark.allow_network def test_chat_completion_doesnt_block_event_loop(caplog): loop = asyncio.new_event_loop() loop.set_debug(True) diff --git a/tests/unit/rag/test_vector_store.py b/tests/unit/rag/test_vector_store.py index dd36d3992..919f97ba7 100644 --- a/tests/unit/rag/test_vector_store.py +++ b/tests/unit/rag/test_vector_store.py @@ -123,6 +123,7 @@ class TestVectorStore: content = await content_from_doc(doc) assert content in DUMMY_PDF_TEXT_CHOICES + @pytest.mark.allow_network async def test_downloads_pdf_and_returns_content(self): # Using GitHub to host the PDF file url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf" @@ -135,6 +136,7 @@ class TestVectorStore: content = await content_from_doc(doc) assert content in DUMMY_PDF_TEXT_CHOICES + @pytest.mark.allow_network async def test_downloads_pdf_and_returns_content_with_url_object(self): # Using GitHub to host the PDF file url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf" diff --git a/uv.lock b/uv.lock index bca12fc51..83e502e7f 100644 --- a/uv.lock +++ b/uv.lock @@ -1324,6 +1324,7 @@ dev = [ { name = "pytest-cov" }, { name = "pytest-html" }, { name = "pytest-json-report" }, + { name = "pytest-socket" }, { name = "pytest-timeout" }, { name = "ruamel-yaml" }, { name = "ruff" }, @@ -1432,6 +1433,7 @@ dev = [ { name = "pytest-cov" }, { name = "pytest-html" }, { name = "pytest-json-report" }, + { name = "pytest-socket" }, { name = "pytest-timeout" }, { name = "ruamel-yaml" }, { name = "ruff" }, @@ -2545,6 +2547,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3e/43/7e7b2ec865caa92f67b8f0e9231a798d102724ca4c0e1f414316be1c1ef2/pytest_metadata-3.1.1-py3-none-any.whl", hash = "sha256:c8e0844db684ee1c798cfa38908d20d67d0463ecb6137c72e91f418558dd5f4b", size = 11428, upload-time = "2024-02-12T19:38:42.531Z" }, ] +[[package]] +name = "pytest-socket" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/ff/90c7e1e746baf3d62ce864c479fd53410b534818b9437413903596f81580/pytest_socket-0.7.0.tar.gz", hash = "sha256:71ab048cbbcb085c15a4423b73b619a8b35d6a307f46f78ea46be51b1b7e11b3", size = 12389, upload-time = "2024-01-28T20:17:23.177Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/19/58/5d14cb5cb59409e491ebe816c47bf81423cd03098ea92281336320ae5681/pytest_socket-0.7.0-py3-none-any.whl", hash = 
"sha256:7e0f4642177d55d317bbd58fc68c6bd9048d6eadb2d46a89307fa9221336ce45", size = 6754, upload-time = "2024-01-28T20:17:22.105Z" }, +] + [[package]] name = "pytest-timeout" version = "2.4.0" From 958fc92b1bc99ba8e57e0819696f74a7e09f45f0 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Sun, 13 Jul 2025 04:03:55 -0400 Subject: [PATCH 03/12] feat: Add Vector stores UI (#2737) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? - Adds two pages to UI - Vector stores - Vector store detail view - Fixed darkmode navbar highlighting - Updated darkmode font color - Updated llama-stack-client package Screenshot 2025-07-12 at 11 34
35 PM Screenshot 2025-07-12 at 11 57
09 PM ## Test Plan --------- Signed-off-by: Francisco Javier Arceo --- .../ui/app/logs/vector-stores/[id]/page.tsx | 82 +++ .../ui/app/logs/vector-stores/layout.tsx | 16 + .../ui/app/logs/vector-stores/page.tsx | 121 +++++ .../ui/components/layout/app-sidebar.tsx | 16 +- .../ui/components/layout/detail-layout.tsx | 8 +- .../ui/components/ui/message-components.tsx | 4 +- .../vector-stores/vector-store-detail.tsx | 128 +++++ llama_stack/ui/package-lock.json | 474 +----------------- llama_stack/ui/package.json | 2 +- 9 files changed, 378 insertions(+), 473 deletions(-) create mode 100644 llama_stack/ui/app/logs/vector-stores/[id]/page.tsx create mode 100644 llama_stack/ui/app/logs/vector-stores/layout.tsx create mode 100644 llama_stack/ui/app/logs/vector-stores/page.tsx create mode 100644 llama_stack/ui/components/vector-stores/vector-store-detail.tsx diff --git a/llama_stack/ui/app/logs/vector-stores/[id]/page.tsx b/llama_stack/ui/app/logs/vector-stores/[id]/page.tsx new file mode 100644 index 000000000..f27c9d802 --- /dev/null +++ b/llama_stack/ui/app/logs/vector-stores/[id]/page.tsx @@ -0,0 +1,82 @@ +"use client"; + +import { useEffect, useState } from "react"; +import { useParams, useRouter } from "next/navigation"; +import { useAuthClient } from "@/hooks/use-auth-client"; +import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores"; +import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files"; +import { VectorStoreDetailView } from "@/components/vector-stores/vector-store-detail"; + +export default function VectorStoreDetailPage() { + const params = useParams(); + const id = params.id as string; + const client = useAuthClient(); + const router = useRouter(); + + const [store, setStore] = useState(null); + const [files, setFiles] = useState([]); + const [isLoadingStore, setIsLoadingStore] = useState(true); + const [isLoadingFiles, setIsLoadingFiles] = useState(true); + const [errorStore, setErrorStore] = useState(null); + const [errorFiles, setErrorFiles] = useState(null); + + useEffect(() => { + if (!id) { + setErrorStore(new Error("Vector Store ID is missing.")); + setIsLoadingStore(false); + return; + } + const fetchStore = async () => { + setIsLoadingStore(true); + setErrorStore(null); + try { + const response = await client.vectorStores.retrieve(id); + setStore(response as VectorStore); + } catch (err) { + setErrorStore( + err instanceof Error + ? err + : new Error("Failed to load vector store."), + ); + } finally { + setIsLoadingStore(false); + } + }; + fetchStore(); + }, [id, client]); + + useEffect(() => { + if (!id) { + setErrorFiles(new Error("Vector Store ID is missing.")); + setIsLoadingFiles(false); + return; + } + const fetchFiles = async () => { + setIsLoadingFiles(true); + setErrorFiles(null); + try { + const result = await client.vectorStores.files.list(id as any); + setFiles((result as any).data); + } catch (err) { + setErrorFiles( + err instanceof Error ? 
err : new Error("Failed to load files."), + ); + } finally { + setIsLoadingFiles(false); + } + }; + fetchFiles(); + }, [id]); + + return ( + + ); +} diff --git a/llama_stack/ui/app/logs/vector-stores/layout.tsx b/llama_stack/ui/app/logs/vector-stores/layout.tsx new file mode 100644 index 000000000..9245f5486 --- /dev/null +++ b/llama_stack/ui/app/logs/vector-stores/layout.tsx @@ -0,0 +1,16 @@ +"use client"; + +import React from "react"; +import LogsLayout from "@/components/layout/logs-layout"; + +export default function VectorStoresLayout({ + children, +}: { + children: React.ReactNode; +}) { + return ( + + {children} + + ); +} diff --git a/llama_stack/ui/app/logs/vector-stores/page.tsx b/llama_stack/ui/app/logs/vector-stores/page.tsx new file mode 100644 index 000000000..29e1fabd6 --- /dev/null +++ b/llama_stack/ui/app/logs/vector-stores/page.tsx @@ -0,0 +1,121 @@ +"use client"; + +import React from "react"; +import { useAuthClient } from "@/hooks/use-auth-client"; +import type { + ListVectorStoresResponse, + VectorStore, +} from "llama-stack-client/resources/vector-stores/vector-stores"; +import { useRouter } from "next/navigation"; +import { usePagination } from "@/hooks/use-pagination"; +import { + Table, + TableBody, + TableCaption, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/ui/table"; +import { Skeleton } from "@/components/ui/skeleton"; + +export default function VectorStoresPage() { + const client = useAuthClient(); + const router = useRouter(); + const { + data: stores, + status, + hasMore, + error, + loadMore, + } = usePagination({ + limit: 20, + order: "desc", + fetchFunction: async (client, params) => { + const response = await client.vectorStores.list({ + after: params.after, + limit: params.limit, + order: params.order, + } as any); + return response as ListVectorStoresResponse; + }, + errorMessagePrefix: "vector stores", + }); + + // Auto-load all pages for infinite scroll behavior (like Responses) + React.useEffect(() => { + if (status === "idle" && hasMore) { + loadMore(); + } + }, [status, hasMore, loadMore]); + + if (status === "loading") { + return ( +
+ + + +
+ ); + } + + if (status === "error") { + return
Error: {error?.message}
; + } + + if (!stores || stores.length === 0) { + return

No vector stores found.

; + } + + return ( +
+ + + + ID + Name + Created + Completed + Cancelled + Failed + In Progress + Total + Usage Bytes + Provider ID + Provider Vector DB ID + + + + {stores.map((store) => { + const fileCounts = store.file_counts; + const metadata = store.metadata || {}; + const providerId = metadata.provider_id ?? ""; + const providerDbId = metadata.provider_vector_db_id ?? ""; + + return ( + router.push(`/logs/vector-stores/${store.id}`)} + className="cursor-pointer hover:bg-muted/50" + > + {store.id} + {store.name} + + {new Date(store.created_at * 1000).toLocaleString()} + + {fileCounts.completed} + {fileCounts.cancelled} + {fileCounts.failed} + {fileCounts.in_progress} + {fileCounts.total} + {store.usage_bytes} + {providerId} + {providerDbId} + + ); + })} + +
+
+ ); +} diff --git a/llama_stack/ui/components/layout/app-sidebar.tsx b/llama_stack/ui/components/layout/app-sidebar.tsx index 1c53d6cc5..532e43dbd 100644 --- a/llama_stack/ui/components/layout/app-sidebar.tsx +++ b/llama_stack/ui/components/layout/app-sidebar.tsx @@ -1,6 +1,11 @@ "use client"; -import { MessageSquareText, MessagesSquare, MoveUpRight } from "lucide-react"; +import { + MessageSquareText, + MessagesSquare, + MoveUpRight, + Database, +} from "lucide-react"; import Link from "next/link"; import { usePathname } from "next/navigation"; import { cn } from "@/lib/utils"; @@ -28,6 +33,11 @@ const logItems = [ url: "/logs/responses", icon: MessagesSquare, }, + { + title: "Vector Stores", + url: "/logs/vector-stores", + icon: Database, + }, { title: "Documentation", url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html", @@ -57,13 +67,13 @@ export function AppSidebar() { className={cn( "justify-start", isActive && - "bg-gray-200 hover:bg-gray-200 text-primary hover:text-primary", + "bg-gray-200 dark:bg-gray-700 hover:bg-gray-200 dark:hover:bg-gray-700 text-gray-900 dark:text-gray-100", )} > diff --git a/llama_stack/ui/components/layout/detail-layout.tsx b/llama_stack/ui/components/layout/detail-layout.tsx index 58b912703..3013195a2 100644 --- a/llama_stack/ui/components/layout/detail-layout.tsx +++ b/llama_stack/ui/components/layout/detail-layout.tsx @@ -93,7 +93,9 @@ export function PropertyItem({ > {label}:{" "} {typeof value === "string" || typeof value === "number" ? ( - {value} + + {value} + ) : ( value )} @@ -112,7 +114,9 @@ export function PropertiesCard({ children }: PropertiesCardProps) { Properties -
    {children}
+
    + {children} +
); diff --git a/llama_stack/ui/components/ui/message-components.tsx b/llama_stack/ui/components/ui/message-components.tsx index 50ccd623e..39cb570b7 100644 --- a/llama_stack/ui/components/ui/message-components.tsx +++ b/llama_stack/ui/components/ui/message-components.tsx @@ -17,10 +17,10 @@ export const MessageBlock: React.FC = ({ }) => { return (
-

+

{label} {labelDetail && ( - + {labelDetail} )} diff --git a/llama_stack/ui/components/vector-stores/vector-store-detail.tsx b/llama_stack/ui/components/vector-stores/vector-store-detail.tsx new file mode 100644 index 000000000..7c5c91dd3 --- /dev/null +++ b/llama_stack/ui/components/vector-stores/vector-store-detail.tsx @@ -0,0 +1,128 @@ +"use client"; + +import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores"; +import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files"; +import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; +import { Skeleton } from "@/components/ui/skeleton"; +import { + DetailLoadingView, + DetailErrorView, + DetailNotFoundView, + DetailLayout, + PropertiesCard, + PropertyItem, +} from "@/components/layout/detail-layout"; +import { + Table, + TableBody, + TableCaption, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/ui/table"; + +interface VectorStoreDetailViewProps { + store: VectorStore | null; + files: VectorStoreFile[]; + isLoadingStore: boolean; + isLoadingFiles: boolean; + errorStore: Error | null; + errorFiles: Error | null; + id: string; +} + +export function VectorStoreDetailView({ + store, + files, + isLoadingStore, + isLoadingFiles, + errorStore, + errorFiles, + id, +}: VectorStoreDetailViewProps) { + const title = "Vector Store Details"; + + if (errorStore) { + return ; + } + if (isLoadingStore) { + return ; + } + if (!store) { + return ; + } + + const mainContent = ( + <> + + + Files + + + {isLoadingFiles ? ( + + ) : errorFiles ? ( +

+ Error loading files: {errorFiles.message} +
+ ) : files.length > 0 ? ( + + Files in this vector store + + + ID + Status + Created + Usage Bytes + + + + {files.map((file) => ( + + {file.id} + {file.status} + + {new Date(file.created_at * 1000).toLocaleString()} + + {file.usage_bytes} + + ))} + +
+ ) : ( +

+ No files in this vector store. +

+ )} + + + + ); + + const sidebar = ( + + + + + + + + + + + ); + + return ( + + ); +} diff --git a/llama_stack/ui/package-lock.json b/llama_stack/ui/package-lock.json index 8fd5fb56c..158569241 100644 --- a/llama_stack/ui/package-lock.json +++ b/llama_stack/ui/package-lock.json @@ -15,7 +15,7 @@ "@radix-ui/react-tooltip": "^1.2.6", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", - "llama-stack-client": "0.2.13", + "llama-stack-client": "^0.2.14", "lucide-react": "^0.510.0", "next": "15.3.3", "next-auth": "^4.24.11", @@ -676,406 +676,6 @@ "tslib": "^2.4.0" } }, - "node_modules/@esbuild/aix-ppc64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.5.tgz", - "integrity": "sha512-9o3TMmpmftaCMepOdA5k/yDw8SfInyzWWTjYTFCX3kPSDJMROQTb8jg+h9Cnwnmm1vOzvxN7gIfB5V2ewpjtGA==", - "cpu": [ - "ppc64" - ], - "license": "MIT", - "optional": true, - "os": [ - "aix" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/android-arm": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.25.5.tgz", - "integrity": "sha512-AdJKSPeEHgi7/ZhuIPtcQKr5RQdo6OO2IL87JkianiMYMPbCtot9fxPbrMiBADOWWm3T2si9stAiVsGbTQFkbA==", - "cpu": [ - "arm" - ], - "license": "MIT", - "optional": true, - "os": [ - "android" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/android-arm64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.25.5.tgz", - "integrity": "sha512-VGzGhj4lJO+TVGV1v8ntCZWJktV7SGCs3Pn1GRWI1SBFtRALoomm8k5E9Pmwg3HOAal2VDc2F9+PM/rEY6oIDg==", - "cpu": [ - "arm64" - ], - "license": "MIT", - "optional": true, - "os": [ - "android" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/android-x64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.25.5.tgz", - "integrity": "sha512-D2GyJT1kjvO//drbRT3Hib9XPwQeWd9vZoBJn+bu/lVsOZ13cqNdDeqIF/xQ5/VmWvMduP6AmXvylO/PIc2isw==", - "cpu": [ - "x64" - ], - "license": "MIT", - "optional": true, - "os": [ - "android" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/darwin-arm64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.25.5.tgz", - "integrity": "sha512-GtaBgammVvdF7aPIgH2jxMDdivezgFu6iKpmT+48+F8Hhg5J/sfnDieg0aeG/jfSvkYQU2/pceFPDKlqZzwnfQ==", - "cpu": [ - "arm64" - ], - "license": "MIT", - "optional": true, - "os": [ - "darwin" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/darwin-x64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.25.5.tgz", - "integrity": "sha512-1iT4FVL0dJ76/q1wd7XDsXrSW+oLoquptvh4CLR4kITDtqi2e/xwXwdCVH8hVHU43wgJdsq7Gxuzcs6Iq/7bxQ==", - "cpu": [ - "x64" - ], - "license": "MIT", - "optional": true, - "os": [ - "darwin" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/freebsd-arm64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.25.5.tgz", - "integrity": "sha512-nk4tGP3JThz4La38Uy/gzyXtpkPW8zSAmoUhK9xKKXdBCzKODMc2adkB2+8om9BDYugz+uGV7sLmpTYzvmz6Sw==", - "cpu": [ - "arm64" - ], - "license": "MIT", - "optional": true, - "os": [ - "freebsd" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/freebsd-x64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.25.5.tgz", - "integrity": 
"sha512-PrikaNjiXdR2laW6OIjlbeuCPrPaAl0IwPIaRv+SMV8CiM8i2LqVUHFC1+8eORgWyY7yhQY+2U2fA55mBzReaw==", - "cpu": [ - "x64" - ], - "license": "MIT", - "optional": true, - "os": [ - "freebsd" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/linux-arm": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.25.5.tgz", - "integrity": "sha512-cPzojwW2okgh7ZlRpcBEtsX7WBuqbLrNXqLU89GxWbNt6uIg78ET82qifUy3W6OVww6ZWobWub5oqZOVtwolfw==", - "cpu": [ - "arm" - ], - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/linux-arm64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.25.5.tgz", - "integrity": "sha512-Z9kfb1v6ZlGbWj8EJk9T6czVEjjq2ntSYLY2cw6pAZl4oKtfgQuS4HOq41M/BcoLPzrUbNd+R4BXFyH//nHxVg==", - "cpu": [ - "arm64" - ], - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/linux-ia32": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.25.5.tgz", - "integrity": "sha512-sQ7l00M8bSv36GLV95BVAdhJ2QsIbCuCjh/uYrWiMQSUuV+LpXwIqhgJDcvMTj+VsQmqAHL2yYaasENvJ7CDKA==", - "cpu": [ - "ia32" - ], - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/linux-loong64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.25.5.tgz", - "integrity": "sha512-0ur7ae16hDUC4OL5iEnDb0tZHDxYmuQyhKhsPBV8f99f6Z9KQM02g33f93rNH5A30agMS46u2HP6qTdEt6Q1kg==", - "cpu": [ - "loong64" - ], - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/linux-mips64el": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.25.5.tgz", - "integrity": "sha512-kB/66P1OsHO5zLz0i6X0RxlQ+3cu0mkxS3TKFvkb5lin6uwZ/ttOkP3Z8lfR9mJOBk14ZwZ9182SIIWFGNmqmg==", - "cpu": [ - "mips64el" - ], - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/linux-ppc64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.25.5.tgz", - "integrity": "sha512-UZCmJ7r9X2fe2D6jBmkLBMQetXPXIsZjQJCjgwpVDz+YMcS6oFR27alkgGv3Oqkv07bxdvw7fyB71/olceJhkQ==", - "cpu": [ - "ppc64" - ], - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/linux-riscv64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.25.5.tgz", - "integrity": "sha512-kTxwu4mLyeOlsVIFPfQo+fQJAV9mh24xL+y+Bm6ej067sYANjyEw1dNHmvoqxJUCMnkBdKpvOn0Ahql6+4VyeA==", - "cpu": [ - "riscv64" - ], - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/linux-s390x": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.25.5.tgz", - "integrity": "sha512-K2dSKTKfmdh78uJ3NcWFiqyRrimfdinS5ErLSn3vluHNeHVnBAFWC8a4X5N+7FgVE1EjXS1QDZbpqZBjfrqMTQ==", - "cpu": [ - "s390x" - ], - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/linux-x64": { - "version": "0.25.5", - "resolved": 
"https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.25.5.tgz", - "integrity": "sha512-uhj8N2obKTE6pSZ+aMUbqq+1nXxNjZIIjCjGLfsWvVpy7gKCOL6rsY1MhRh9zLtUtAI7vpgLMK6DxjO8Qm9lJw==", - "cpu": [ - "x64" - ], - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/netbsd-arm64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/netbsd-arm64/-/netbsd-arm64-0.25.5.tgz", - "integrity": "sha512-pwHtMP9viAy1oHPvgxtOv+OkduK5ugofNTVDilIzBLpoWAM16r7b/mxBvfpuQDpRQFMfuVr5aLcn4yveGvBZvw==", - "cpu": [ - "arm64" - ], - "license": "MIT", - "optional": true, - "os": [ - "netbsd" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/netbsd-x64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.25.5.tgz", - "integrity": "sha512-WOb5fKrvVTRMfWFNCroYWWklbnXH0Q5rZppjq0vQIdlsQKuw6mdSihwSo4RV/YdQ5UCKKvBy7/0ZZYLBZKIbwQ==", - "cpu": [ - "x64" - ], - "license": "MIT", - "optional": true, - "os": [ - "netbsd" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/openbsd-arm64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/openbsd-arm64/-/openbsd-arm64-0.25.5.tgz", - "integrity": "sha512-7A208+uQKgTxHd0G0uqZO8UjK2R0DDb4fDmERtARjSHWxqMTye4Erz4zZafx7Di9Cv+lNHYuncAkiGFySoD+Mw==", - "cpu": [ - "arm64" - ], - "license": "MIT", - "optional": true, - "os": [ - "openbsd" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/openbsd-x64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.25.5.tgz", - "integrity": "sha512-G4hE405ErTWraiZ8UiSoesH8DaCsMm0Cay4fsFWOOUcz8b8rC6uCvnagr+gnioEjWn0wC+o1/TAHt+It+MpIMg==", - "cpu": [ - "x64" - ], - "license": "MIT", - "optional": true, - "os": [ - "openbsd" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/sunos-x64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.25.5.tgz", - "integrity": "sha512-l+azKShMy7FxzY0Rj4RCt5VD/q8mG/e+mDivgspo+yL8zW7qEwctQ6YqKX34DTEleFAvCIUviCFX1SDZRSyMQA==", - "cpu": [ - "x64" - ], - "license": "MIT", - "optional": true, - "os": [ - "sunos" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/win32-arm64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.25.5.tgz", - "integrity": "sha512-O2S7SNZzdcFG7eFKgvwUEZ2VG9D/sn/eIiz8XRZ1Q/DO5a3s76Xv0mdBzVM5j5R639lXQmPmSo0iRpHqUUrsxw==", - "cpu": [ - "arm64" - ], - "license": "MIT", - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/win32-ia32": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.25.5.tgz", - "integrity": "sha512-onOJ02pqs9h1iMJ1PQphR+VZv8qBMQ77Klcsqv9CNW2w6yLqoURLcgERAIurY6QE63bbLuqgP9ATqajFLK5AMQ==", - "cpu": [ - "ia32" - ], - "license": "MIT", - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/@esbuild/win32-x64": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.25.5.tgz", - "integrity": "sha512-TXv6YnJ8ZMVdX+SXWVBo/0p8LTcrUYngpWjvm91TMjjBQii7Oz11Lw5lbDV5Y0TzuhSJHwiH4hEtC1I42mMS0g==", - "cpu": [ - "x64" - ], - "license": "MIT", - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">=18" - } - }, "node_modules/@eslint-community/eslint-utils": { 
"version": "4.7.0", "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.7.0.tgz", @@ -5999,46 +5599,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/esbuild": { - "version": "0.25.5", - "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.25.5.tgz", - "integrity": "sha512-P8OtKZRv/5J5hhz0cUAdu/cLuPIKXpQl1R9pZtvmHWQvrAUVd0UNIPT4IB4W3rNOqVO0rlqHmCIbSwxh/c9yUQ==", - "hasInstallScript": true, - "license": "MIT", - "bin": { - "esbuild": "bin/esbuild" - }, - "engines": { - "node": ">=18" - }, - "optionalDependencies": { - "@esbuild/aix-ppc64": "0.25.5", - "@esbuild/android-arm": "0.25.5", - "@esbuild/android-arm64": "0.25.5", - "@esbuild/android-x64": "0.25.5", - "@esbuild/darwin-arm64": "0.25.5", - "@esbuild/darwin-x64": "0.25.5", - "@esbuild/freebsd-arm64": "0.25.5", - "@esbuild/freebsd-x64": "0.25.5", - "@esbuild/linux-arm": "0.25.5", - "@esbuild/linux-arm64": "0.25.5", - "@esbuild/linux-ia32": "0.25.5", - "@esbuild/linux-loong64": "0.25.5", - "@esbuild/linux-mips64el": "0.25.5", - "@esbuild/linux-ppc64": "0.25.5", - "@esbuild/linux-riscv64": "0.25.5", - "@esbuild/linux-s390x": "0.25.5", - "@esbuild/linux-x64": "0.25.5", - "@esbuild/netbsd-arm64": "0.25.5", - "@esbuild/netbsd-x64": "0.25.5", - "@esbuild/openbsd-arm64": "0.25.5", - "@esbuild/openbsd-x64": "0.25.5", - "@esbuild/sunos-x64": "0.25.5", - "@esbuild/win32-arm64": "0.25.5", - "@esbuild/win32-ia32": "0.25.5", - "@esbuild/win32-x64": "0.25.5" - } - }, "node_modules/escalade": { "version": "3.2.0", "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", @@ -6993,6 +6553,7 @@ "version": "2.3.3", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "dev": true, "hasInstallScript": true, "license": "MIT", "optional": true, @@ -7154,6 +6715,7 @@ "version": "4.10.0", "resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.10.0.tgz", "integrity": "sha512-kGzZ3LWWQcGIAmg6iWvXn0ei6WDtV26wzHRMwDSzmAbcXrTEXxHy6IehI6/4eT6VRKyMP1eF1VqwrVUmE/LR7A==", + "dev": true, "license": "MIT", "dependencies": { "resolve-pkg-maps": "^1.0.0" @@ -9537,9 +9099,10 @@ "license": "MIT" }, "node_modules/llama-stack-client": { - "version": "0.2.13", - "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.13.tgz", - "integrity": "sha512-R1rTFLwgUimr+KjEUkzUvFL6vLASwS9qj3UDSVkJ5BmrKAs5GwVAMeL7yZaTBXGuPUVh124WSlC4d9H0FjWqLA==", + "version": "0.2.14", + "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.14.tgz", + "integrity": "sha512-bVU3JHp+EPEKR0Vb9vcd9ZyQj/72jSDuptKLwOXET9WrkphIQ8xuW5ueecMTgq8UEls3lwB3HiZM2cDOR9eDsQ==", + "license": "Apache-2.0", "dependencies": { "@types/node": "^18.11.18", "@types/node-fetch": "^2.6.4", @@ -9547,8 +9110,7 @@ "agentkeepalive": "^4.2.1", "form-data-encoder": "1.7.2", "formdata-node": "^4.3.2", - "node-fetch": "^2.6.7", - "tsx": "^4.19.2" + "node-fetch": "^2.6.7" } }, "node_modules/llama-stack-client/node_modules/@types/node": { @@ -11148,6 +10710,7 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/resolve-pkg-maps/-/resolve-pkg-maps-1.0.0.tgz", "integrity": "sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw==", + "dev": true, "license": "MIT", "funding": { "url": "https://github.com/privatenumber/resolve-pkg-maps?sponsor=1" @@ -12198,25 +11761,6 @@ "integrity": 
"sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", "license": "0BSD" }, - "node_modules/tsx": { - "version": "4.19.4", - "resolved": "https://registry.npmjs.org/tsx/-/tsx-4.19.4.tgz", - "integrity": "sha512-gK5GVzDkJK1SI1zwHf32Mqxf2tSJkNx+eYcNly5+nHvWqXUJYUkWBQtKauoESz3ymezAI++ZwT855x5p5eop+Q==", - "license": "MIT", - "dependencies": { - "esbuild": "~0.25.0", - "get-tsconfig": "^4.7.5" - }, - "bin": { - "tsx": "dist/cli.mjs" - }, - "engines": { - "node": ">=18.0.0" - }, - "optionalDependencies": { - "fsevents": "~2.3.3" - } - }, "node_modules/tw-animate-css": { "version": "1.2.9", "resolved": "https://registry.npmjs.org/tw-animate-css/-/tw-animate-css-1.2.9.tgz", diff --git a/llama_stack/ui/package.json b/llama_stack/ui/package.json index 9524ce0a5..b38efe309 100644 --- a/llama_stack/ui/package.json +++ b/llama_stack/ui/package.json @@ -20,7 +20,7 @@ "@radix-ui/react-tooltip": "^1.2.6", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", - "llama-stack-client": "0.2.13", + "llama-stack-client": "^0.2.14", "lucide-react": "^0.510.0", "next": "15.3.3", "next-auth": "^4.24.11", From 618ccea0908bc03914f71fdb6aca5062c8f01699 Mon Sep 17 00:00:00 2001 From: Mark Campbell Date: Mon, 14 Jul 2025 14:11:34 +0100 Subject: [PATCH 04/12] feat: add input validation for search mode of rag query config (#2275) # What does this PR do? Adds input validation for mode in RagQueryConfig This will prevent users from inputting search modes other than `vector` and `keyword` for the time being with `hybrid` to follow when that functionality is implemented. ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] ``` # Check out this PR and enter the LS directory uv sync --extra dev ``` Run the quickstart [example](https://llama-stack.readthedocs.io/en/latest/getting_started/#step-3-run-the-demo) Alter the Agent to include a query_config ``` agent = Agent( client, model=model_id, instructions="You are a helpful assistant", tools=[ { "name": "builtin::rag/knowledge_search", "args": { "vector_db_ids": [vector_db_id], "query_config": { "mode": "i-am-not-vector", # Test for non valid search mode "max_chunks": 6 } }, } ], ) ``` Ensure you get the following error: ``` 400: {'errors': [{'loc': ['mode'], 'msg': "Value error, mode must be either 'vector' or 'keyword' if supported by the vector_io provider", 'type': 'value_error'}]} ``` ## Running unit tests ``` uv sync --extra dev uv run pytest tests/unit/rag/test_rag_query.py -v ``` [//]: # (## Documentation) --- docs/_static/llama-stack-spec.html | 13 ++++++++++++- docs/_static/llama-stack-spec.yaml | 14 +++++++++++++- llama_stack/apis/tools/rag_tool.py | 16 +++++++++++++++- tests/unit/rag/test_rag_query.py | 12 ++++++++++++ 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 8021e0e55..6794d1fbb 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -14796,7 +14796,8 @@ "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). 
Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\"" }, "mode": { - "type": "string", + "$ref": "#/components/schemas/RAGSearchMode", + "default": "vector", "description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"." }, "ranker": { @@ -14831,6 +14832,16 @@ } } }, + "RAGSearchMode": { + "type": "string", + "enum": [ + "vector", + "keyword", + "hybrid" + ], + "title": "RAGSearchMode", + "description": "Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching - KEYWORD: Uses keyword-based search for exact matching - HYBRID: Combines both vector and keyword search for better results" + }, "RRFRanker": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index a18474646..548c5a988 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -10346,7 +10346,8 @@ components: content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" mode: - type: string + $ref: '#/components/schemas/RAGSearchMode' + default: vector description: >- Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector". @@ -10373,6 +10374,17 @@ components: mapping: default: '#/components/schemas/DefaultRAGQueryGeneratorConfig' llm: '#/components/schemas/LLMRAGQueryGeneratorConfig' + RAGSearchMode: + type: string + enum: + - vector + - keyword + - hybrid + title: RAGSearchMode + description: >- + Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search + for semantic matching - KEYWORD: Uses keyword-based search for exact matching + - HYBRID: Combines both vector and keyword search for better results RRFRanker: type: object properties: diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index d497fe1a7..cfaa49488 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -87,6 +87,20 @@ class RAGQueryGenerator(Enum): custom = "custom" +@json_schema_type +class RAGSearchMode(Enum): + """ + Search modes for RAG query retrieval: + - VECTOR: Uses vector similarity search for semantic matching + - KEYWORD: Uses keyword-based search for exact matching + - HYBRID: Combines both vector and keyword search for better results + """ + + VECTOR = "vector" + KEYWORD = "keyword" + HYBRID = "hybrid" + + @json_schema_type class DefaultRAGQueryGeneratorConfig(BaseModel): type: Literal["default"] = "default" @@ -128,7 +142,7 @@ class RAGQueryConfig(BaseModel): max_tokens_in_context: int = 4096 max_chunks: int = 5 chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" - mode: str | None = None + mode: RAGSearchMode | None = RAGSearchMode.VECTOR ranker: Ranker | None = Field(default=None) # Only used for hybrid mode @field_validator("chunk_template") diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index b2baa744a..ad155c205 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from llama_stack.apis.tools.rag_tool import RAGQueryConfig from llama_stack.apis.vector_io import ( Chunk, ChunkMetadata, @@ -58,3 +59,14 @@ class TestRagQuery: ) assert expected_metadata_string in result.content[1].text assert result.content is not None + + async def test_query_raises_incorrect_mode(self): + 
with pytest.raises(ValueError): + RAGQueryConfig(mode="invalid_mode") + + @pytest.mark.asyncio + async def test_query_accepts_valid_modes(self): + RAGQueryConfig() # Test default (vector) + RAGQueryConfig(mode="vector") # Test vector + RAGQueryConfig(mode="keyword") # Test keyword + RAGQueryConfig(mode="hybrid") # Test hybrid From 77d2c8e95d95625e506d7ae06105e213b69351d9 Mon Sep 17 00:00:00 2001 From: Sumanth Kamenani Date: Mon, 14 Jul 2025 12:53:13 -0400 Subject: [PATCH 05/12] docs: clarify run.yaml files are starting points for customization (#2746) # What does this PR do? This PR improves documentation clarity around run.yaml file usage. It adds comprehensive guidance to help users understand that generated run.yaml files are templates meant to be customized for production use, not used as-is. ## Changes - Add new documentation section on customizing run.yaml files - Clarify that generated run.yaml files are templates, not production configs - Add guidance on customization best practices and common scenarios - Update existing documentation to reference customization guide - Improve clarity around run.yaml file usage for better user experience ## Test Plan - Verified new documentation file exists at correct location - Confirmed documentation is properly integrated into the toctree structure - Checked all internal links use correct paths and reference existing files - Validated references are added to relevant existing documentation files - Documentation build testing will be handled by CI environment --- docs/source/distributions/building_distro.md | 4 ++ docs/source/distributions/configuration.md | 4 ++ .../distributions/customizing_run_yaml.md | 40 +++++++++++++++++++ docs/source/distributions/index.md | 1 + .../getting_started/detailed_tutorial.md | 2 +- 5 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 docs/source/distributions/customizing_run_yaml.md diff --git a/docs/source/distributions/building_distro.md b/docs/source/distributions/building_distro.md index f24974dd3..cd2c6b6a8 100644 --- a/docs/source/distributions/building_distro.md +++ b/docs/source/distributions/building_distro.md @@ -145,6 +145,10 @@ $ llama stack build --template starter ... You can now edit ~/.llama/distributions/llamastack-starter/starter-run.yaml and run `llama stack run ~/.llama/distributions/llamastack-starter/starter-run.yaml` ``` + +```{tip} +The generated `run.yaml` file is a starting point for your configuration. For comprehensive guidance on customizing it for your specific needs, infrastructure, and deployment scenarios, see [Customizing Your run.yaml Configuration](customizing_run_yaml.md). +``` ::: :::{tab-item} Building from Scratch diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 4709cb8c6..9548780c6 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -2,6 +2,10 @@ The Llama Stack runtime configuration is specified as a YAML file. Here is a simplified version of an example configuration file for the Ollama distribution: +```{note} +The default `run.yaml` files generated by templates are starting points for your configuration. For guidance on customizing these files for your specific needs, see [Customizing Your run.yaml Configuration](customizing_run_yaml.md). 
+``` + ```{dropdown} 👋 Click here for a Sample Configuration File ```yaml diff --git a/docs/source/distributions/customizing_run_yaml.md b/docs/source/distributions/customizing_run_yaml.md new file mode 100644 index 000000000..10067bab7 --- /dev/null +++ b/docs/source/distributions/customizing_run_yaml.md @@ -0,0 +1,40 @@ +# Customizing run.yaml Files + +The `run.yaml` files generated by Llama Stack templates are **starting points** designed to be customized for your specific needs. They are not meant to be used as-is in production environments. + +## Key Points + +- **Templates are starting points**: Generated `run.yaml` files contain defaults for development/testing +- **Customization expected**: Update URLs, credentials, models, and settings for your environment +- **Version control separately**: Keep customized configs in your own repository +- **Environment-specific**: Create different configurations for dev, staging, production + +## What You Can Customize + +You can customize: +- **Provider endpoints**: Change `http://localhost:8000` to your actual servers +- **Swap providers**: Replace default providers (e.g., swap Tavily with Brave for search) +- **Storage paths**: Move from `/tmp/` to production directories +- **Authentication**: Add API keys, SSL, timeouts +- **Models**: Different model sizes for dev vs prod +- **Database settings**: Switch from SQLite to PostgreSQL +- **Tool configurations**: Add custom tools and integrations + +## Best Practices + +- Use environment variables for secrets and environment-specific values +- Create separate `run.yaml` files for different environments (dev, staging, prod) +- Document your changes with comments +- Test configurations before deployment +- Keep your customized configs in version control + +Example structure: +``` +your-project/ +├── configs/ +│ ├── dev-run.yaml +│ ├── prod-run.yaml +└── README.md +``` + +The goal is to take the generated template and adapt it to your specific infrastructure and operational needs. \ No newline at end of file diff --git a/docs/source/distributions/index.md b/docs/source/distributions/index.md index 103a6131f..600eec3a1 100644 --- a/docs/source/distributions/index.md +++ b/docs/source/distributions/index.md @@ -9,6 +9,7 @@ This section provides an overview of the distributions available in Llama Stack. importing_as_library configuration +customizing_run_yaml list_of_distributions kubernetes_deployment building_distro diff --git a/docs/source/getting_started/detailed_tutorial.md b/docs/source/getting_started/detailed_tutorial.md index 35cb7f02e..97e7df774 100644 --- a/docs/source/getting_started/detailed_tutorial.md +++ b/docs/source/getting_started/detailed_tutorial.md @@ -54,7 +54,7 @@ Llama Stack is a server that exposes multiple APIs, you connect with it using th You can use Python to build and run the Llama Stack server, which is useful for testing and development. Llama Stack uses a [YAML configuration file](../distributions/configuration.md) to specify the stack setup, -which defines the providers and their settings. +which defines the providers and their settings. The generated configuration serves as a starting point that you can [customize for your specific needs](../distributions/customizing_run_yaml.md). Now let's build and run the Llama Stack config for Ollama. We use `starter` as template. By default all providers are disabled, this requires enable ollama by passing environment variables. 
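To make the customization guidance in this patch concrete, here is a minimal sketch of the kind of environment-driven override the new `customizing_run_yaml.md` page describes. The keys, environment-variable names, and the `${env.*}` substitution syntax shown here are assumptions for illustration only and are not part of this patch; check the `run.yaml` generated by your own template for the exact structure your distribution uses.

```yaml
# Illustrative fragment of a customized run.yaml (assumed keys and env-substitution syntax).
# Dev defaults are replaced with environment-driven values so the same file can be adapted
# per environment (dev, staging, prod) without editing the generated template in place.
providers:
  inference:
    - provider_id: ollama
      provider_type: remote::ollama
      config:
        # point at your own inference server instead of the localhost default
        url: ${env.OLLAMA_URL:=http://localhost:11434}
metadata_store:
  type: sqlite
  # move state out of /tmp/ for production deployments
  db_path: ${env.LLAMA_STACK_STATE_DIR:=/var/lib/llama-stack}/registry.db
```

Keeping such per-environment files (for example `configs/dev-run.yaml` and `configs/prod-run.yaml`) in your own repository matches the project layout suggested in the guide.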
From a7ed86181c8b823ca13af0c911be694e49899be9 Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Mon, 14 Jul 2025 18:58:23 +0100 Subject: [PATCH 06/12] fix(faiss): Delete file contents from kvstore (#2686) Remove both the metadata and content from the kvstore when a file is being removed from the vector store. Closes: #2685 Also add faiss provider to openai_vector_stores test suite --------- Signed-off-by: Derek Higgins Co-authored-by: raghotham --- .../providers/inline/vector_io/faiss/faiss.py | 21 +++++++-- tests/unit/providers/vector_io/conftest.py | 44 ++++++++++++++++++- .../test_vector_io_openai_vector_stores.py | 6 ++- 3 files changed, 63 insertions(+), 8 deletions(-) diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 62a98413d..0306d9156 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -267,6 +267,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr assert self.kvstore is not None key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" await self.kvstore.set(key=key, value=json.dumps(store_info)) + self.openai_vector_stores[store_id] = store_info async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: """Load all vector store metadata from kvstore.""" @@ -286,17 +287,20 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr assert self.kvstore is not None key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" await self.kvstore.set(key=key, value=json.dumps(store_info)) + self.openai_vector_stores[store_id] = store_info async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: """Delete vector store metadata from kvstore.""" assert self.kvstore is not None key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" await self.kvstore.delete(key) + if store_id in self.openai_vector_stores: + del self.openai_vector_stores[store_id] async def _save_openai_vector_store_file( self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]] ) -> None: - """Save vector store file metadata to kvstore.""" + """Save vector store file data to kvstore.""" assert self.kvstore is not None key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}" await self.kvstore.set(key=key, value=json.dumps(file_info)) @@ -324,7 +328,16 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr await self.kvstore.set(key=key, value=json.dumps(file_info)) async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None: - """Delete vector store file metadata from kvstore.""" + """Delete vector store data from kvstore.""" assert self.kvstore is not None - key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}" - await self.kvstore.delete(key) + + keys_to_delete = [ + f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}", + f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}", + ] + for key in keys_to_delete: + try: + await self.kvstore.delete(key) + except Exception as e: + logger.warning(f"Failed to delete key {key}: {e}") + continue diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 4a9639326..9f86f877d 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -12,6 +12,8 @@ from pymilvus import MilvusClient, connections from 
llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, ChunkMetadata +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig +from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter @@ -90,7 +92,7 @@ def sample_embeddings_with_metadata(sample_chunks_with_metadata): return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata]) -@pytest.fixture(params=["milvus", "sqlite_vec"]) +@pytest.fixture(params=["milvus", "sqlite_vec", "faiss"]) def vector_provider(request): return request.param @@ -116,7 +118,7 @@ async def unique_kvstore_config(tmp_path_factory): @pytest.fixture(scope="session") def sqlite_vec_db_path(tmp_path_factory): - db_path = str(tmp_path_factory.getbasetemp() / "test.db") + db_path = str(tmp_path_factory.getbasetemp() / "test_sqlite_vec.db") return db_path @@ -198,11 +200,49 @@ async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api): await adapter.shutdown() +@pytest.fixture +def faiss_vec_db_path(tmp_path_factory): + db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db") + return db_path + + +@pytest.fixture +async def faiss_vec_index(embedding_dimension): + index = FaissIndex(embedding_dimension) + yield index + await index.delete() + + +@pytest.fixture +async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension): + config = FaissVectorIOConfig( + kvstore=unique_kvstore_config, + ) + adapter = FaissVectorIOAdapter( + config=config, + inference_api=mock_inference_api, + files_api=None, + ) + await adapter.initialize() + await adapter.register_vector_db( + VectorDB( + identifier=f"faiss_test_collection_{np.random.randint(1e6)}", + provider_id="test_provider", + embedding_model="test_model", + embedding_dimension=embedding_dimension, + ) + ) + yield adapter + await adapter.shutdown() + + @pytest.fixture def vector_io_adapter(vector_provider, request): """Returns the appropriate vector IO adapter based on the provider parameter.""" if vector_provider == "milvus": return request.getfixturevalue("milvus_vec_adapter") + elif vector_provider == "faiss": + return request.getfixturevalue("faiss_vec_adapter") else: return request.getfixturevalue("sqlite_vec_adapter") diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py index 97e2f085e..bf7663d2e 100644 --- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py +++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py @@ -94,7 +94,7 @@ async def test_query_unregistered_raises(vector_io_adapter): async def test_insert_chunks_calls_underlying_index(vector_io_adapter): fake_index = AsyncMock() - vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index) + vector_io_adapter.cache["db1"] = fake_index chunks = ["chunk1", "chunk2"] await vector_io_adapter.insert_chunks("db1", chunks) @@ -112,7 +112,7 @@ async def test_insert_chunks_missing_db_raises(vector_io_adapter): async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter): expected = 
QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1]) fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected)) - vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index) + vector_io_adapter.cache["db1"] = fake_index response = await vector_io_adapter.query_chunks("db1", "my_query", {"param": 1}) @@ -286,5 +286,7 @@ async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, t await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) await vector_io_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id) + loaded_file_info = await vector_io_adapter._load_openai_vector_store_file(store_id, file_id) + assert loaded_file_info == {} loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id) assert loaded_contents == [] From f731f369a2be1ad1d28ff8b3e1b9f59ae261451e Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Mon, 14 Jul 2025 14:38:53 -0400 Subject: [PATCH 07/12] feat: add infrastructure to allow inference model discovery (#2710) # What does this PR do? inference providers each have a static list of supported / known models. some also have access to a dynamic list of currently available models. this change gives prodivers using the ModelRegistryHelper the ability to combine their static and dynamic lists. for instance, OpenAIInferenceAdapter can implement ``` def query_available_models(self) -> list[str]: return [entry.model for entry in self.openai_client.models.list()] ``` to augment its static list w/ a current list from openai. ## Test Plan scripts/unit-test.sh --- .../utils/inference/model_registry.py | 33 ++++++- .../providers/utils/test_model_registry.py | 91 +++++++++++++++++++ 2 files changed, 122 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index c2fc13e07..801b8ea06 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -83,9 +83,37 @@ class ModelRegistryHelper(ModelsProtocolPrivate): def get_llama_model(self, provider_model_id: str) -> str | None: return self.provider_id_to_llama_model_map.get(provider_model_id, None) + async def check_model_availability(self, model: str) -> bool: + """ + Check if a specific model is available from the provider (non-static check). + + This is for subclassing purposes, so providers can check if a specific + model is currently available for use through dynamic means (e.g., API calls). + + This method should NOT check statically configured model entries in + `self.alias_to_provider_id_map` - that is handled separately in register_model. + + Default implementation returns False (no dynamic models available). + + :param model: The model identifier to check. + :return: True if the model is available dynamically, False otherwise. 
+ """ + return False + async def register_model(self, model: Model) -> Model: - if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)): - raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys()) + # Check if model is supported in static configuration + supported_model_id = self.get_provider_model_id(model.provider_resource_id) + + # If not found in static config, check if it's available dynamically from provider + if not supported_model_id: + if await self.check_model_availability(model.provider_resource_id): + supported_model_id = model.provider_resource_id + else: + # note: we cannot provide a complete list of supported models without + # getting a complete list from the provider, so we return "..." + all_supported_models = [*self.alias_to_provider_id_map.keys(), "..."] + raise UnsupportedModelError(model.provider_resource_id, all_supported_models) + provider_resource_id = self.get_provider_model_id(model.model_id) if model.model_type == ModelType.embedding: # embedding models are always registered by their provider model id and does not need to be mapped to a llama model @@ -114,6 +142,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate): ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model] ) + # Register the model alias, ensuring it maps to the correct provider model id self.alias_to_provider_id_map[model.model_id] = supported_model_id return model diff --git a/tests/unit/providers/utils/test_model_registry.py b/tests/unit/providers/utils/test_model_registry.py index e11f95d49..1a1705961 100644 --- a/tests/unit/providers/utils/test_model_registry.py +++ b/tests/unit/providers/utils/test_model_registry.py @@ -87,6 +87,37 @@ def helper(known_provider_model: ProviderModelEntry, known_provider_model2: Prov return ModelRegistryHelper([known_provider_model, known_provider_model2]) +class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper): + """Test helper that simulates a provider with dynamically available models.""" + + def __init__(self, model_entries: list[ProviderModelEntry], available_models: list[str]): + super().__init__(model_entries) + self._available_models = available_models + + async def check_model_availability(self, model: str) -> bool: + return model in self._available_models + + +@pytest.fixture +def dynamic_model() -> Model: + """A model that's not in static config but available dynamically.""" + return Model( + provider_id="provider", + identifier="dynamic-model", + provider_resource_id="dynamic-provider-id", + ) + + +@pytest.fixture +def helper_with_dynamic_models( + known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry, dynamic_model: Model +) -> MockModelRegistryHelperWithDynamicModels: + """Helper that includes dynamically available models.""" + return MockModelRegistryHelperWithDynamicModels( + [known_provider_model, known_provider_model2], [dynamic_model.provider_resource_id] + ) + + async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None: assert helper.get_provider_model_id(unknown_model.model_id) is None @@ -151,3 +182,63 @@ async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_m assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id await helper.unregister_model(known_model.provider_resource_id) assert helper.get_provider_model_id(known_model.provider_resource_id) is None + + +async def test_register_model_from_check_model_availability( + 
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model +) -> None: + """Test that models returned by check_model_availability can be registered.""" + # Verify the model is not in static config + assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None + + # But it should be available via check_model_availability + is_available = await helper_with_dynamic_models.check_model_availability(dynamic_model.provider_resource_id) + assert is_available + + # Registration should succeed + registered_model = await helper_with_dynamic_models.register_model(dynamic_model) + assert registered_model == dynamic_model + + # Model should now be registered and accessible + assert ( + helper_with_dynamic_models.get_provider_model_id(dynamic_model.model_id) == dynamic_model.provider_resource_id + ) + + +async def test_register_model_not_in_static_or_dynamic( + helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, unknown_model: Model +) -> None: + """Test that models not in static config or dynamic models are rejected.""" + # Verify the model is not in static config + assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None + + # And not available via check_model_availability + is_available = await helper_with_dynamic_models.check_model_availability(unknown_model.provider_resource_id) + assert not is_available + + # Registration should fail with comprehensive error message + with pytest.raises(Exception) as exc_info: # UnsupportedModelError + await helper_with_dynamic_models.register_model(unknown_model) + + # Error should include static models and "..." for dynamic models + error_str = str(exc_info.value) + assert "..." in error_str # "..." should be in error message + + +async def test_register_alias_for_dynamic_model( + helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model +) -> None: + """Test that we can register an alias that maps to a dynamically available model.""" + # Create a model with a different identifier but same provider_resource_id + alias_model = Model( + provider_id=dynamic_model.provider_id, + identifier="dynamic-model-alias", + provider_resource_id=dynamic_model.provider_resource_id, + ) + + # Registration should succeed since the provider_resource_id is available dynamically + registered_model = await helper_with_dynamic_models.register_model(alias_model) + assert registered_model == alias_model + + # Both the original provider_resource_id and the new alias should work + assert helper_with_dynamic_models.get_provider_model_id(alias_model.model_id) == dynamic_model.provider_resource_id From aa0840c2813c6140840dd91d5a40ee3850a6a889 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Mon, 14 Jul 2025 12:06:56 -0700 Subject: [PATCH 08/12] docs: fix building distro link (#2750) # What does this PR do? ## Test Plan Co-authored-by: raghotham --- docs/source/getting_started/detailed_tutorial.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/getting_started/detailed_tutorial.md b/docs/source/getting_started/detailed_tutorial.md index 97e7df774..fc59022f9 100644 --- a/docs/source/getting_started/detailed_tutorial.md +++ b/docs/source/getting_started/detailed_tutorial.md @@ -77,7 +77,7 @@ ENABLE_OLLAMA=ollama INFERENCE_MODEL="llama3.2:3b" llama stack build --template You can use a container image to run the Llama Stack server. 
We provide several container images for the server component that works with different inference providers out of the box. For this guide, we will use `llamastack/distribution-starter` as the container image. If you'd like to build your own image or customize the -configurations, please check out [this guide](../references/index.md). +configurations, please check out [this guide](../distributions/building_distro.md). First lets setup some environment variables and create a local directory to mount into the container’s file system. ```bash export INFERENCE_MODEL="llama3.2:3b" From 6ad22c209ff53f996a30d94f9769c9e405f3bc71 Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Mon, 14 Jul 2025 17:41:44 -0400 Subject: [PATCH 09/12] chore: add issue template for technical debt (#2753) # What does this PR do? Adds a template for technical debt. Currently we don't support blank issues so everything filed has to a bug or a feature. This would allow maintainers as well as community members to track things we might want to merge to expose the functionality but should be addressed later. Such things can also be "good first issues" for new contributors. ## Example of what we constitute as technical debt Inelegant code solutions, tests we intend to temporarily disable but would like to restore, CI hacks around infrastructure or installation, etc. Signed-off-by: Nathan Weinberg --- .github/ISSUE_TEMPLATE/tech-debt.yml | 30 ++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/tech-debt.yml diff --git a/.github/ISSUE_TEMPLATE/tech-debt.yml b/.github/ISSUE_TEMPLATE/tech-debt.yml new file mode 100644 index 000000000..b281b3482 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/tech-debt.yml @@ -0,0 +1,30 @@ +name: 🔧 Tech Debt +description: Something that is functional but should be improved or optimizied +labels: ["tech-debt"] +body: +- type: textarea + id: tech-debt-explanation + attributes: + label: 🤔 What is the technical debt you think should be addressed? + description: > + A clear and concise description of _what_ needs to be addressed - ensure you are describing + constitutes [technical debt](https://en.wikipedia.org/wiki/Technical_debt) and is not a bug + or feature request. + validations: + required: true + +- type: textarea + id: tech-debt-motivation + attributes: + label: 💡 What is the benefit of addressing this technical debt? + description: > + A clear and concise description of _why_ this work is needed. + validations: + required: true + +- type: textarea + id: other-thoughts + attributes: + label: Other thoughts + description: > + Any thoughts about how this may result in complexity in the codebase, or other trade-offs. From 6b8a8c1be9c61cf1135422f2105c4cc0040e19fc Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Mon, 14 Jul 2025 15:07:40 -0700 Subject: [PATCH 10/12] fix: Safety in starter (#2731) - fireworks, together do not support Llama-guard 3 8b model anymore - Need to default to ollama - current safety shields logic was not correct since the shield_id was the provider ( which had duplicates ) - Followed similar logic to models Note: Seems a bit over-engineered but this can now be extended to other providers and fits in the overall mechanism of how env_vars are used to manage starter. 
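To make the new mechanism concrete, here is a minimal sketch (illustrative only; variable names are mine, and the real logic is `get_shield_registry` in `llama_stack/templates/template.py` later in this patch) of how an env-var-gated provider and the `SAFETY_MODEL` env var are expected to combine into a single shield entry in the generated `run.yaml`:

```
# Illustrative sketch only -- mirrors the run.yaml shields entry generated by this patch.
# Both values stay "__disabled__" unless the corresponding env vars are set at runtime.
provider_id = "${env.ENABLE_OLLAMA:=__disabled__}"
safety_model = "${env.SAFETY_MODEL:=__disabled__}"

shield = {
    "shield_id": safety_model,
    "provider_shield_id": f"{provider_id}/{safety_model}",
}
# With ENABLE_OLLAMA=ollama and SAFETY_MODEL=llama-guard3:1b this resolves to
# shield_id="llama-guard3:1b" and provider_shield_id="ollama/llama-guard3:1b".
```
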
### How to test ``` ENABLE_OLLAMA=ollama ENABLE_FIREWORKS=fireworks SAFETY_MODEL=llama-guard3:1b pytest -s -v tests/integration/ --stack-config starter -k 'not(supervised_fine_tune or builtin_tool_code or safety_with_image or code_interpreter_for or rag_and_code or truncation or register_and_unregister)' --text-model fireworks/meta-llama/Llama-3.3-70B-Instruct --vision-model fireworks/meta-llama/Llama-4-Scout-17B-16E-Instruct --safety-shield llama-guard3:1b --embedding-model all-MiniLM-L6-v2 ``` ### Related but not obvious in this PR In the llama-stack-ops repo, we run tests before publishing packages and docker containers. The actions in that repo were using the fireworks / together distros ( which are non-existent ) So need to update that to run with `starter` and use `ollama` specifically for safety. --- .github/workflows/integration-tests.yml | 2 +- .../remote/inference/ollama/models.py | 25 +++-- llama_stack/templates/nvidia/nvidia.py | 2 +- .../open-benchmark/open_benchmark.py | 3 +- llama_stack/templates/starter/run.yaml | 20 +--- llama_stack/templates/starter/starter.py | 94 +++------------- llama_stack/templates/template.py | 47 +++++++- llama_stack/templates/watsonx/watsonx.py | 2 +- tests/integration/agents/test_agents.py | 104 ++++-------------- 9 files changed, 104 insertions(+), 195 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index c46100c38..1a8d6734f 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -89,7 +89,7 @@ jobs: -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \ --text-model="ollama/llama3.2:3b-instruct-fp16" \ --embedding-model=all-MiniLM-L6-v2 \ - --safety-shield=ollama \ + --safety-shield=$SAFETY_MODEL \ --color=yes \ --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log diff --git a/llama_stack/providers/remote/inference/ollama/models.py b/llama_stack/providers/remote/inference/ollama/models.py index 64ddb23d9..7c0a19a1a 100644 --- a/llama_stack/providers/remote/inference/ollama/models.py +++ b/llama_stack/providers/remote/inference/ollama/models.py @@ -12,6 +12,19 @@ from llama_stack.providers.utils.inference.model_registry import ( build_model_entry, ) +SAFETY_MODELS_ENTRIES = [ + # The Llama Guard models don't have their full fp16 versions + # so we are going to alias their default version to the canonical SKU + build_hf_repo_model_entry( + "llama-guard3:8b", + CoreModelId.llama_guard_3_8b.value, + ), + build_hf_repo_model_entry( + "llama-guard3:1b", + CoreModelId.llama_guard_3_1b.value, + ), +] + MODEL_ENTRIES = [ build_hf_repo_model_entry( "llama3.1:8b-instruct-fp16", @@ -73,16 +86,6 @@ MODEL_ENTRIES = [ "llama3.3:70b", CoreModelId.llama3_3_70b_instruct.value, ), - # The Llama Guard models don't have their full fp16 versions - # so we are going to alias their default version to the canonical SKU - build_hf_repo_model_entry( - "llama-guard3:8b", - CoreModelId.llama_guard_3_8b.value, - ), - build_hf_repo_model_entry( - "llama-guard3:1b", - CoreModelId.llama_guard_3_1b.value, - ), ProviderModelEntry( provider_model_id="all-minilm:l6-v2", aliases=["all-minilm"], @@ -100,4 +103,4 @@ MODEL_ENTRIES = [ "context_length": 8192, }, ), -] +] + SAFETY_MODELS_ENTRIES diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index 4eccfb25c..e5c13aa74 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -68,7 +68,7 @@ def 
get_distribution_template() -> DistributionTemplate: ), ] - default_models = get_model_registry(available_models) + default_models, _ = get_model_registry(available_models) return DistributionTemplate( name="nvidia", distro_type="self_hosted", diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py index 942905dae..56ee9c47d 100644 --- a/llama_stack/templates/open-benchmark/open_benchmark.py +++ b/llama_stack/templates/open-benchmark/open_benchmark.py @@ -146,7 +146,8 @@ def get_distribution_template() -> DistributionTemplate: ), ] - default_models = get_model_registry(available_models) + [ + models, _ = get_model_registry(available_models) + default_models = models + [ ModelInput( model_id="meta-llama/Llama-3.3-70B-Instruct", provider_id="groq", diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index 888a2c3bf..ad449cb1b 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -1171,24 +1171,8 @@ models: provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers} model_type: embedding shields: -- shield_id: ${env.ENABLE_OLLAMA:=__disabled__} - provider_id: llama-guard - provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b} -- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__} - provider_id: llama-guard - provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b} -- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__} - provider_id: llama-guard - provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-11b-vision} -- shield_id: ${env.ENABLE_TOGETHER:=__disabled__} - provider_id: llama-guard - provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B} -- shield_id: ${env.ENABLE_TOGETHER:=__disabled__} - provider_id: llama-guard - provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-11B-Vision-Turbo} -- shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__} - provider_id: llama-guard - provider_shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/${env.SAFETY_MODEL:=sambanova/Meta-Llama-Guard-3-8B} +- shield_id: ${env.SAFETY_MODEL:=__disabled__} + provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__} vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py index 6b8aa8974..c0ac44183 100644 --- a/llama_stack/templates/starter/starter.py +++ b/llama_stack/templates/starter/starter.py @@ -12,7 +12,6 @@ from llama_stack.distribution.datatypes import ( ModelInput, Provider, ProviderSpec, - ShieldInput, ToolGroupInput, ) from llama_stack.distribution.utils.dynamic import instantiate_class_type @@ -32,75 +31,39 @@ from llama_stack.providers.registry.inference import available_providers from llama_stack.providers.remote.inference.anthropic.models import ( MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES, ) -from llama_stack.providers.remote.inference.anthropic.models import ( - SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES, -) from llama_stack.providers.remote.inference.bedrock.models import ( MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES, ) -from llama_stack.providers.remote.inference.bedrock.models import ( - SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES, -) from 
llama_stack.providers.remote.inference.cerebras.models import ( MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES, ) -from llama_stack.providers.remote.inference.cerebras.models import ( - SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES, -) from llama_stack.providers.remote.inference.databricks.databricks import ( MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES, ) -from llama_stack.providers.remote.inference.databricks.databricks import ( - SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES, -) from llama_stack.providers.remote.inference.fireworks.models import ( MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES, ) -from llama_stack.providers.remote.inference.fireworks.models import ( - SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES, -) from llama_stack.providers.remote.inference.gemini.models import ( MODEL_ENTRIES as GEMINI_MODEL_ENTRIES, ) -from llama_stack.providers.remote.inference.gemini.models import ( - SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES, -) from llama_stack.providers.remote.inference.groq.models import ( MODEL_ENTRIES as GROQ_MODEL_ENTRIES, ) -from llama_stack.providers.remote.inference.groq.models import ( - SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES, -) from llama_stack.providers.remote.inference.nvidia.models import ( MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES, ) -from llama_stack.providers.remote.inference.nvidia.models import ( - SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES, -) from llama_stack.providers.remote.inference.openai.models import ( MODEL_ENTRIES as OPENAI_MODEL_ENTRIES, ) -from llama_stack.providers.remote.inference.openai.models import ( - SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES, -) from llama_stack.providers.remote.inference.runpod.runpod import ( MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES, ) -from llama_stack.providers.remote.inference.runpod.runpod import ( - SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES, -) from llama_stack.providers.remote.inference.sambanova.models import ( MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES, ) -from llama_stack.providers.remote.inference.sambanova.models import ( - SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES, -) from llama_stack.providers.remote.inference.together.models import ( MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES, ) -from llama_stack.providers.remote.inference.together.models import ( - SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES, -) from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig from llama_stack.providers.remote.vector_io.pgvector.config import ( PGVectorVectorIOConfig, @@ -111,6 +74,7 @@ from llama_stack.templates.template import ( DistributionTemplate, RunConfigSettings, get_model_registry, + get_shield_registry, ) @@ -164,28 +128,13 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]: """Get model entries for a specific provider type.""" safety_model_entries_map = { - "openai": OPENAI_SAFETY_MODELS_ENTRIES, - "fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES, - "together": TOGETHER_SAFETY_MODELS_ENTRIES, - "anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES, - "gemini": GEMINI_SAFETY_MODELS_ENTRIES, - "groq": GROQ_SAFETY_MODELS_ENTRIES, - "sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES, - "cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES, - "bedrock": BEDROCK_SAFETY_MODELS_ENTRIES, - "databricks": DATABRICKS_SAFETY_MODELS_ENTRIES, - "nvidia": NVIDIA_SAFETY_MODELS_ENTRIES, - "runpod": RUNPOD_SAFETY_MODELS_ENTRIES, 
- } - - # Special handling for providers with dynamic model entries - if provider_type == "ollama": - return [ + "ollama": [ ProviderModelEntry( - provider_model_id="llama-guard3:1b", + provider_model_id="${env.SAFETY_MODEL:=__disabled__}", model_type=ModelType.llm, ), - ] + ], + } return safety_model_entries_map.get(provider_type, []) @@ -246,28 +195,20 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro # build a list of shields for all possible providers -def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]: - shields = [] +def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]: + available_models = {} for provider in providers: provider_type = provider.provider_type.split("::")[1] safety_model_entries = _get_model_safety_entries_for_provider(provider_type) if len(safety_model_entries) == 0: continue - if provider.provider_id: - shield_id = provider.provider_id - else: - raise ValueError(f"Provider {provider.provider_type} has no provider_id") - for safety_model_entry in safety_model_entries: - print(f"provider.provider_id: {provider.provider_id}") - print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}") - shields.append( - ShieldInput( - provider_id="llama-guard", - shield_id=shield_id, - provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}", - ) - ) - return shields + + env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}" + provider_id = f"${{env.{env_var}:=__disabled__}}" + + available_models[provider_id] = safety_model_entries + + return available_models def get_distribution_template() -> DistributionTemplate: @@ -307,8 +248,6 @@ def get_distribution_template() -> DistributionTemplate: ), ] - shields = get_shields_for_providers(remote_inference_providers) - providers = { "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]), "vector_io": ([p.provider_type for p in vector_io_providers]), @@ -361,7 +300,10 @@ def get_distribution_template() -> DistributionTemplate: }, ) - default_models = get_model_registry(available_models) + default_models, ids_conflict_in_models = get_model_registry(available_models) + + available_safety_models = get_safety_models_for_providers(remote_inference_providers) + shields = get_shield_registry(available_safety_models, ids_conflict_in_models) return DistributionTemplate( name=name, diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index dceb13c8b..fb2528873 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -37,7 +37,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge def get_model_registry( available_models: dict[str, list[ProviderModelEntry]], -) -> list[ModelInput]: +) -> tuple[list[ModelInput], bool]: models = [] # check for conflicts in model ids @@ -74,7 +74,50 @@ def get_model_registry( metadata=entry.metadata, ) ) - return models + return models, ids_conflict + + +def get_shield_registry( + available_safety_models: dict[str, list[ProviderModelEntry]], + ids_conflict_in_models: bool, +) -> list[ShieldInput]: + shields = [] + + # check for conflicts in shield ids + all_ids = set() + ids_conflict = False + + for _, entries in available_safety_models.items(): + for entry in entries: + ids = [entry.provider_model_id] + entry.aliases + for model_id in ids: + if model_id in all_ids: + 
ids_conflict = True + rich.print( + f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]" + ) + break + all_ids.update(ids) + if ids_conflict: + break + if ids_conflict: + break + + for provider_id, entries in available_safety_models.items(): + for entry in entries: + ids = [entry.provider_model_id] + entry.aliases + for model_id in ids: + identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id + shields.append( + ShieldInput( + shield_id=identifier, + provider_shield_id=f"{provider_id}/{entry.provider_model_id}" + if ids_conflict_in_models + else entry.provider_model_id, + ) + ) + + return shields class DefaultModel(BaseModel): diff --git a/llama_stack/templates/watsonx/watsonx.py b/llama_stack/templates/watsonx/watsonx.py index 7fa3a55e5..ea185f05d 100644 --- a/llama_stack/templates/watsonx/watsonx.py +++ b/llama_stack/templates/watsonx/watsonx.py @@ -69,7 +69,7 @@ def get_distribution_template() -> DistributionTemplate: }, ) - default_models = get_model_registry(available_models) + default_models, _ = get_model_registry(available_models) return DistributionTemplate( name="watsonx", distro_type="remote_hosted", diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py index 66c9ab829..05549cf18 100644 --- a/tests/integration/agents/test_agents.py +++ b/tests/integration/agents/test_agents.py @@ -77,6 +77,24 @@ def agent_config(llama_stack_client, text_model_id): return agent_config +@pytest.fixture(scope="session") +def agent_config_without_safety(text_model_id): + agent_config = dict( + model=text_model_id, + instructions="You are a helpful assistant", + sampling_params={ + "strategy": { + "type": "top_p", + "temperature": 0.0001, + "top_p": 0.9, + }, + }, + tools=[], + enable_session_persistence=False, + ) + return agent_config + + def test_agent_simple(llama_stack_client, agent_config): agent = Agent(llama_stack_client, **agent_config) session_id = agent.create_session(f"test-session-{uuid4()}") @@ -491,7 +509,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name): assert expected_kw in response.output_message.content.lower() -def test_rag_agent_with_attachments(llama_stack_client, agent_config): +def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety): urls = ["llama3.rst", "lora_finetune.rst"] documents = [ # passign as url @@ -514,14 +532,8 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config): metadata={}, ), ] - rag_agent = Agent(llama_stack_client, **agent_config) + rag_agent = Agent(llama_stack_client, **agent_config_without_safety) session_id = rag_agent.create_session(f"test-session-{uuid4()}") - user_prompts = [ - ( - "Instead of the standard multi-head attention, what attention type does Llama3-8B use?", - "grouped", - ), - ] user_prompts = [ ( "I am attaching some documentation for Torchtune. 
Help me answer questions I will ask next.", @@ -549,82 +561,6 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config): assert "lora" in response.output_message.content.lower() -@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack") -def test_rag_and_code_agent(llama_stack_client, agent_config): - if "llama-4" in agent_config["model"].lower(): - pytest.xfail("Not working for llama4") - - documents = [] - documents.append( - Document( - document_id="nba_wiki", - content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).", - metadata={}, - ) - ) - documents.append( - Document( - document_id="perplexity_wiki", - content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning: - - Srinivas, the CEO, worked at OpenAI as an AI researcher. - Konwinski was among the founding team at Databricks. - Yarats, the CTO, was an AI research scientist at Meta. - Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""", - metadata={}, - ) - ) - vector_db_id = f"test-vector-db-{uuid4()}" - llama_stack_client.vector_dbs.register( - vector_db_id=vector_db_id, - embedding_model="all-MiniLM-L6-v2", - embedding_dimension=384, - ) - llama_stack_client.tool_runtime.rag_tool.insert( - documents=documents, - vector_db_id=vector_db_id, - chunk_size_in_tokens=128, - ) - agent_config = { - **agent_config, - "tools": [ - dict( - name="builtin::rag/knowledge_search", - args={"vector_db_ids": [vector_db_id]}, - ), - "builtin::code_interpreter", - ], - } - agent = Agent(llama_stack_client, **agent_config) - user_prompts = [ - ( - "when was Perplexity the company founded?", - [], - "knowledge_search", - "2022", - ), - ( - "when was the nba created?", - [], - "knowledge_search", - "1949", - ), - ] - - for prompt, docs, tool_name, expected_kw in user_prompts: - session_id = agent.create_session(f"test-session-{uuid4()}") - response = agent.create_turn( - messages=[{"role": "user", "content": prompt}], - session_id=session_id, - documents=docs, - stream=False, - ) - tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution") - assert tool_execution_step.tool_calls[0].tool_name == tool_name, f"Failed on {prompt}" - if expected_kw: - assert expected_kw in response.output_message.content.lower() - - @pytest.mark.parametrize( "client_tools", [(get_boiling_point, False), (get_boiling_point_with_metadata, True)], From 33f0d83ad396bf647ee999105272829e86321d38 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Mon, 14 Jul 2025 18:10:35 -0400 Subject: [PATCH 11/12] chore: Move vector store `kvstore` implementation into `openai_vector_store_mixin.py` (#2748) --- .../providers/vector_io/remote_pgvector.md | 4 + .../providers/vector_io/remote_weaviate.md | 4 +- .../providers/inline/vector_io/faiss/faiss.py | 40 +---- .../inline/vector_io/sqlite_vec/sqlite_vec.py | 39 +--- .../remote/vector_io/milvus/milvus.py | 33 +--- .../remote/vector_io/pgvector/config.py | 18 +- .../remote/vector_io/pgvector/pgvector.py | 169 ++++++------------ .../remote/vector_io/weaviate/config.py | 17 +- .../remote/vector_io/weaviate/weaviate.py | 64 ++++++- .../utils/memory/openai_vector_store_mixin.py | 41 ++++- .../open-benchmark/open_benchmark.py | 1 + llama_stack/templates/open-benchmark/run.yaml | 3 
+ llama_stack/templates/starter/run.yaml | 3 + llama_stack/templates/starter/starter.py | 1 + 14 files changed, 203 insertions(+), 234 deletions(-) diff --git a/docs/source/providers/vector_io/remote_pgvector.md b/docs/source/providers/vector_io/remote_pgvector.md index 685b98f37..3e7d6e776 100644 --- a/docs/source/providers/vector_io/remote_pgvector.md +++ b/docs/source/providers/vector_io/remote_pgvector.md @@ -40,6 +40,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de | `db` | `str \| None` | No | postgres | | | `user` | `str \| None` | No | postgres | | | `password` | `str \| None` | No | mysecretpassword | | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) | ## Sample Configuration @@ -49,6 +50,9 @@ port: ${env.PGVECTOR_PORT:=5432} db: ${env.PGVECTOR_DB} user: ${env.PGVECTOR_USER} password: ${env.PGVECTOR_PASSWORD} +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/pgvector_registry.db ``` diff --git a/docs/source/providers/vector_io/remote_weaviate.md b/docs/source/providers/vector_io/remote_weaviate.md index b7f811c35..d930515d5 100644 --- a/docs/source/providers/vector_io/remote_weaviate.md +++ b/docs/source/providers/vector_io/remote_weaviate.md @@ -36,7 +36,9 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more ## Sample Configuration ```yaml -{} +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/weaviate_registry.db ``` diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 0306d9156..2a1370c56 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -181,8 +181,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr ) self.cache[vector_db.identifier] = index - # Load existing OpenAI vector stores using the mixin method - self.openai_vector_stores = await self._load_openai_vector_stores() + # Load existing OpenAI vector stores into the in-memory cache + await self.initialize_openai_vector_stores() async def shutdown(self) -> None: # Cleanup if needed @@ -261,42 +261,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr return await index.query_chunks(query, params) - # OpenAI Vector Store Mixin abstract method implementations - async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: - """Save vector store metadata to kvstore.""" - assert self.kvstore is not None - key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" - await self.kvstore.set(key=key, value=json.dumps(store_info)) - self.openai_vector_stores[store_id] = store_info - - async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: - """Load all vector store metadata from kvstore.""" - assert self.kvstore is not None - start_key = OPENAI_VECTOR_STORES_PREFIX - end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff" - stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key) - - stores = {} - for store_data in stored_openai_stores: - store_info = json.loads(store_data) - stores[store_info["id"]] = store_info - return stores - - async def 
_update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: - """Update vector store metadata in kvstore.""" - assert self.kvstore is not None - key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" - await self.kvstore.set(key=key, value=json.dumps(store_info)) - self.openai_vector_stores[store_id] = store_info - - async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: - """Delete vector store metadata from kvstore.""" - assert self.kvstore is not None - key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" - await self.kvstore.delete(key) - if store_id in self.openai_vector_stores: - del self.openai_vector_stores[store_id] - async def _save_openai_vector_store_file( self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]] ) -> None: diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 6acd85c56..771ffa607 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -452,8 +452,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc ) self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) - # load any existing OpenAI vector stores - self.openai_vector_stores = await self._load_openai_vector_stores() + # Load existing OpenAI vector stores into the in-memory cache + await self.initialize_openai_vector_stores() async def shutdown(self) -> None: # nothing to do since we don't maintain a persistent connection @@ -501,41 +501,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc await self.cache[vector_db_id].index.delete() del self.cache[vector_db_id] - # OpenAI Vector Store Mixin abstract method implementations - async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: - """Save vector store metadata to SQLite database.""" - assert self.kvstore is not None - key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" - await self.kvstore.set(key=key, value=json.dumps(store_info)) - self.openai_vector_stores[store_id] = store_info - - async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: - """Load all vector store metadata from SQLite database.""" - assert self.kvstore is not None - start_key = OPENAI_VECTOR_STORES_PREFIX - end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff" - stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key) - stores = {} - for store_data in stored_openai_stores: - store_info = json.loads(store_data) - stores[store_info["id"]] = store_info - return stores - - async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: - """Update vector store metadata in SQLite database.""" - assert self.kvstore is not None - key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" - await self.kvstore.set(key=key, value=json.dumps(store_info)) - self.openai_vector_stores[store_id] = store_info - - async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: - """Delete vector store metadata from SQLite database.""" - assert self.kvstore is not None - key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" - await self.kvstore.delete(key) - if store_id in self.openai_vector_stores: - del self.openai_vector_stores[store_id] - async def _save_openai_vector_store_file( self, store_id: str, file_id: str, file_info: dict[str, Any], 
file_contents: list[dict[str, Any]] ) -> None: diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index a06130fd0..d88d954ef 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -179,7 +179,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP uri = os.path.expanduser(self.config.db_path) self.client = MilvusClient(uri=uri) - self.openai_vector_stores = await self._load_openai_vector_stores() + # Load existing OpenAI vector stores into the in-memory cache + await self.initialize_openai_vector_stores() async def shutdown(self) -> None: self.client.close() @@ -248,36 +249,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP return await index.query_chunks(query, params) - async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: - """Save vector store metadata to persistent storage.""" - assert self.kvstore is not None - key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" - await self.kvstore.set(key=key, value=json.dumps(store_info)) - self.openai_vector_stores[store_id] = store_info - - async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: - """Update vector store metadata in persistent storage.""" - assert self.kvstore is not None - key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" - await self.kvstore.set(key=key, value=json.dumps(store_info)) - self.openai_vector_stores[store_id] = store_info - - async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: - """Delete vector store metadata from persistent storage.""" - assert self.kvstore is not None - key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" - await self.kvstore.delete(key) - if store_id in self.openai_vector_stores: - del self.openai_vector_stores[store_id] - - async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: - """Load all vector store metadata from persistent storage.""" - assert self.kvstore is not None - start_key = OPENAI_VECTOR_STORES_PREFIX - end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff" - stored = await self.kvstore.values_in_range(start_key, end_key) - return {json.loads(s)["id"]: json.loads(s) for s in stored} - async def _save_openai_vector_store_file( self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]] ) -> None: diff --git a/llama_stack/providers/remote/vector_io/pgvector/config.py b/llama_stack/providers/remote/vector_io/pgvector/config.py index 92908aa8a..334cbe5be 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/config.py +++ b/llama_stack/providers/remote/vector_io/pgvector/config.py @@ -8,6 +8,10 @@ from typing import Any from pydantic import BaseModel, Field +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) from llama_stack.schema_utils import json_schema_type @@ -18,10 +22,12 @@ class PGVectorVectorIOConfig(BaseModel): db: str | None = Field(default="postgres") user: str | None = Field(default="postgres") password: str | None = Field(default="mysecretpassword") + kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None) @classmethod def sample_run_config( cls, + __distro_dir__: str, host: str = "${env.PGVECTOR_HOST:=localhost}", port: int = "${env.PGVECTOR_PORT:=5432}", db: str = "${env.PGVECTOR_DB}", @@ 
-29,4 +35,14 @@ class PGVectorVectorIOConfig(BaseModel): password: str = "${env.PGVECTOR_PASSWORD}", **kwargs: Any, ) -> dict[str, Any]: - return {"host": host, "port": port, "db": db, "user": user, "password": password} + return { + "host": host, + "port": port, + "db": db, + "user": user, + "password": password, + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="pgvector_registry.db", + ), + } diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index c3cdef9b8..1bf3eedf8 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -13,24 +13,18 @@ from psycopg2 import sql from psycopg2.extras import Json, execute_values from pydantic import BaseModel, TypeAdapter +from llama_stack.apis.files.files import Files from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import ( Chunk, QueryChunksResponse, - SearchRankingOptions, VectorIO, - VectorStoreChunkingStrategy, - VectorStoreDeleteResponse, - VectorStoreFileContentsResponse, - VectorStoreFileObject, - VectorStoreFileStatus, - VectorStoreListFilesResponse, - VectorStoreListResponse, - VectorStoreObject, - VectorStoreSearchResponsePage, ) from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate +from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.kvstore.api import KVStore +from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, @@ -40,6 +34,13 @@ from .config import PGVectorVectorIOConfig log = logging.getLogger(__name__) +VERSION = "v3" +VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::" +VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::" +OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::" + def check_extension_version(cur): cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'") @@ -69,7 +70,7 @@ def load_models(cur, cls): class PGVectorIndex(EmbeddingIndex): - def __init__(self, vector_db: VectorDB, dimension: int, conn): + def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None): self.conn = conn with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: # Sanitize the table name by replacing hyphens with underscores @@ -77,6 +78,7 @@ class PGVectorIndex(EmbeddingIndex): # when created with patterns like "test-vector-db-{uuid4()}" sanitized_identifier = vector_db.identifier.replace("-", "_") self.table_name = f"vector_store_{sanitized_identifier}" + self.kvstore = kvstore cur.execute( f""" @@ -158,15 +160,28 @@ class PGVectorIndex(EmbeddingIndex): cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") -class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): - def __init__(self, config: PGVectorVectorIOConfig, inference_api: Api.inference) -> None: +class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): + def __init__( + self, + config: PGVectorVectorIOConfig, + inference_api: Api.inference, + files_api: Files | None 
= None, + ) -> None: self.config = config self.inference_api = inference_api self.conn = None self.cache = {} + self.files_api = files_api + self.kvstore: KVStore | None = None + self.vector_db_store = None + self.openai_vector_store: dict[str, dict[str, Any]] = {} + self.metadatadata_collection_name = "openai_vector_stores_metadata" async def initialize(self) -> None: log.info(f"Initializing PGVector memory adapter with config: {self.config}") + self.kvstore = await kvstore_impl(self.config.kvstore) + await self.initialize_openai_vector_stores() + try: self.conn = psycopg2.connect( host=self.config.host, @@ -201,14 +216,31 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): log.info("Connection to PGVector database server closed") async def register_vector_db(self, vector_db: VectorDB) -> None: + # Persist vector DB metadata in the KV store + assert self.kvstore is not None + key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}" + await self.kvstore.set(key=key, value=vector_db.model_dump_json()) + + # Upsert model metadata in Postgres upsert_models(self.conn, [(vector_db.identifier, vector_db)]) - index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn) - self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) + # Create and cache the PGVector index table for the vector DB + index = VectorDBWithIndex( + vector_db, + index=PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore), + inference_api=self.inference_api, + ) + self.cache[vector_db.identifier] = index async def unregister_vector_db(self, vector_db_id: str) -> None: - await self.cache[vector_db_id].index.delete() - del self.cache[vector_db_id] + # Remove provider index and cache + if vector_db_id in self.cache: + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] + + # Delete vector DB metadata from KV store + assert self.kvstore is not None + await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}") async def insert_chunks( self, @@ -237,107 +269,20 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api) return self.cache[vector_db_id] - async def openai_create_vector_store( - self, - name: str, - file_ids: list[str] | None = None, - expires_after: dict[str, Any] | None = None, - chunking_strategy: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - embedding_model: str | None = None, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - provider_vector_db_id: str | None = None, - ) -> VectorStoreObject: + # OpenAI Vector Stores File operations are not supported in PGVector + async def _save_openai_vector_store_file( + self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]] + ) -> None: raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector") - async def openai_list_vector_stores( - self, - limit: int | None = 20, - order: str | None = "desc", - after: str | None = None, - before: str | None = None, - ) -> VectorStoreListResponse: + async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]: raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector") - async def openai_retrieve_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreObject: + async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: 
str) -> list[dict[str, Any]]: raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector") - async def openai_update_vector_store( - self, - vector_store_id: str, - name: str | None = None, - expires_after: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - ) -> VectorStoreObject: + async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None: raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector") - async def openai_delete_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreDeleteResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector") - - async def openai_search_vector_store( - self, - vector_store_id: str, - query: str | list[str], - filters: dict[str, Any] | None = None, - max_num_results: int | None = 10, - ranking_options: SearchRankingOptions | None = None, - rewrite_query: bool | None = False, - search_mode: str | None = "vector", - ) -> VectorStoreSearchResponsePage: - raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector") - - async def openai_attach_file_to_vector_store( - self, - vector_store_id: str, - file_id: str, - attributes: dict[str, Any] | None = None, - chunking_strategy: VectorStoreChunkingStrategy | None = None, - ) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector") - - async def openai_list_files_in_vector_store( - self, - vector_store_id: str, - limit: int | None = 20, - order: str | None = "desc", - after: str | None = None, - before: str | None = None, - filter: VectorStoreFileStatus | None = None, - ) -> VectorStoreListFilesResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector") - - async def openai_retrieve_vector_store_file( - self, - vector_store_id: str, - file_id: str, - ) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector") - - async def openai_retrieve_vector_store_file_contents( - self, - vector_store_id: str, - file_id: str, - ) -> VectorStoreFileContentsResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector") - - async def openai_update_vector_store_file( - self, - vector_store_id: str, - file_id: str, - attributes: dict[str, Any] | None = None, - ) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector") - - async def openai_delete_vector_store_file( - self, - vector_store_id: str, - file_id: str, - ) -> VectorStoreFileObject: + async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None: raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector") diff --git a/llama_stack/providers/remote/vector_io/weaviate/config.py b/llama_stack/providers/remote/vector_io/weaviate/config.py index a8c6e3e2c..4283b8d3b 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/config.py +++ b/llama_stack/providers/remote/vector_io/weaviate/config.py @@ -6,15 +6,26 @@ from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) class WeaviateRequestProviderData(BaseModel): weaviate_api_key: str weaviate_cluster_url: str + kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", 
default=None) class WeaviateVectorIOConfig(BaseModel): @classmethod - def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: - return {} + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: + return { + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="weaviate_registry.db", + ), + } diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index c63dd70c6..35bb40454 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -14,10 +14,13 @@ from weaviate.classes.init import Auth from weaviate.classes.query import Filter from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.files.files import Files from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate +from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, @@ -27,11 +30,19 @@ from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig log = logging.getLogger(__name__) +VERSION = "v3" +VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::" +VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::" +OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::" + class WeaviateIndex(EmbeddingIndex): - def __init__(self, client: weaviate.Client, collection_name: str): + def __init__(self, client: weaviate.Client, collection_name: str, kvstore: KVStore | None = None): self.client = client self.collection_name = collection_name + self.kvstore = kvstore async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( @@ -109,11 +120,21 @@ class WeaviateVectorIOAdapter( NeedsRequestProviderData, VectorDBsProtocolPrivate, ): - def __init__(self, config: WeaviateVectorIOConfig, inference_api: Api.inference) -> None: + def __init__( + self, + config: WeaviateVectorIOConfig, + inference_api: Api.inference, + files_api: Files | None, + ) -> None: self.config = config self.inference_api = inference_api self.client_cache = {} self.cache = {} + self.files_api = files_api + self.kvstore: KVStore | None = None + self.vector_db_store = None + self.openai_vector_stores: dict[str, dict[str, Any]] = {} + self.metadata_collection_name = "openai_vector_stores_metadata" def _get_client(self) -> weaviate.Client: provider_data = self.get_request_provider_data() @@ -132,7 +153,26 @@ class WeaviateVectorIOAdapter( return client async def initialize(self) -> None: - pass + """Set up KV store and load existing vector DBs and OpenAI vector stores.""" + # Initialize KV store for metadata + self.kvstore = await kvstore_impl(self.config.kvstore) + + # Load existing vector DB definitions + start_key = VECTOR_DBS_PREFIX + end_key = f"{VECTOR_DBS_PREFIX}\xff" + stored = await self.kvstore.values_in_range(start_key, end_key) + for raw in stored: + vector_db = 
VectorDB.model_validate_json(raw) + client = self._get_client() + idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore) + self.cache[vector_db.identifier] = VectorDBWithIndex( + vector_db=vector_db, + index=idx, + inference_api=self.inference_api, + ) + + # Load OpenAI vector stores metadata into cache + await self.initialize_openai_vector_stores() async def shutdown(self) -> None: for client in self.client_cache.values(): @@ -206,3 +246,21 @@ class WeaviateVectorIOAdapter( raise ValueError(f"Vector DB {vector_db_id} not found") return await index.query_chunks(query, params) + + # OpenAI Vector Stores File operations are not supported in Weaviate + async def _save_openai_vector_store_file( + self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]] + ) -> None: + raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") + + async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]: + raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") + + async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]: + raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") + + async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None: + raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") + + async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None: + raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 7c97ff7f6..27bb1c997 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -5,6 +5,7 @@ # the root directory of this source tree. 
import asyncio +import json import logging import mimetypes import time @@ -35,6 +36,7 @@ from llama_stack.apis.vector_io import ( VectorStoreSearchResponse, VectorStoreSearchResponsePage, ) +from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks logger = logging.getLogger(__name__) @@ -59,26 +61,45 @@ class OpenAIVectorStoreMixin(ABC): # These should be provided by the implementing class openai_vector_stores: dict[str, dict[str, Any]] files_api: Files | None + # KV store for persisting OpenAI vector store metadata + kvstore: KVStore | None - @abstractmethod async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: """Save vector store metadata to persistent storage.""" - pass + assert self.kvstore is not None + key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" + await self.kvstore.set(key=key, value=json.dumps(store_info)) + # update in-memory cache + self.openai_vector_stores[store_id] = store_info - @abstractmethod async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: """Load all vector store metadata from persistent storage.""" - pass + assert self.kvstore is not None + start_key = OPENAI_VECTOR_STORES_PREFIX + end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff" + stored_data = await self.kvstore.values_in_range(start_key, end_key) + + stores: dict[str, dict[str, Any]] = {} + for item in stored_data: + info = json.loads(item) + stores[info["id"]] = info + return stores - @abstractmethod async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: """Update vector store metadata in persistent storage.""" - pass + assert self.kvstore is not None + key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" + await self.kvstore.set(key=key, value=json.dumps(store_info)) + # update in-memory cache + self.openai_vector_stores[store_id] = store_info - @abstractmethod async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: """Delete vector store metadata from persistent storage.""" - pass + assert self.kvstore is not None + key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" + await self.kvstore.delete(key) + # remove from in-memory cache + self.openai_vector_stores.pop(store_id, None) @abstractmethod async def _save_openai_vector_store_file( @@ -117,6 +138,10 @@ class OpenAIVectorStoreMixin(ABC): """Unregister a vector database (provider-specific implementation).""" pass + async def initialize_openai_vector_stores(self) -> None: + """Load existing OpenAI vector stores into the in-memory cache.""" + self.openai_vector_stores = await self._load_openai_vector_stores() + @abstractmethod async def insert_chunks( self, diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py index 56ee9c47d..63a27e07f 100644 --- a/llama_stack/templates/open-benchmark/open_benchmark.py +++ b/llama_stack/templates/open-benchmark/open_benchmark.py @@ -128,6 +128,7 @@ def get_distribution_template() -> DistributionTemplate: provider_id="${env.ENABLE_PGVECTOR:+pgvector}", provider_type="remote::pgvector", config=PGVectorVectorIOConfig.sample_run_config( + f"~/.llama/distributions/{name}", db="${env.PGVECTOR_DB:=}", user="${env.PGVECTOR_USER:=}", password="${env.PGVECTOR_PASSWORD:=}", diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 0b368ebc9..7d07cc4bf 100644 --- 
a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -54,6 +54,9 @@ providers: db: ${env.PGVECTOR_DB:=} user: ${env.PGVECTOR_USER:=} password: ${env.PGVECTOR_PASSWORD:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/pgvector_registry.db safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index ad449cb1b..8e20f5224 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -166,6 +166,9 @@ providers: db: ${env.PGVECTOR_DB:=} user: ${env.PGVECTOR_USER:=} password: ${env.PGVECTOR_PASSWORD:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db files: - provider_id: meta-reference-files provider_type: inline::localfs diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py index c0ac44183..f6ca73028 100644 --- a/llama_stack/templates/starter/starter.py +++ b/llama_stack/templates/starter/starter.py @@ -241,6 +241,7 @@ def get_distribution_template() -> DistributionTemplate: provider_id="${env.ENABLE_PGVECTOR:=__disabled__}", provider_type="remote::pgvector", config=PGVectorVectorIOConfig.sample_run_config( + f"~/.llama/distributions/{name}", db="${env.PGVECTOR_DB:=}", user="${env.PGVECTOR_USER:=}", password="${env.PGVECTOR_PASSWORD:=}", From 4ae5656c2f0668f26d82dc6476a5272db119e177 Mon Sep 17 00:00:00 2001 From: Varsha Date: Mon, 14 Jul 2025 16:39:55 -0700 Subject: [PATCH 12/12] feat: Implement keyword search in milvus (#2231) # What does this PR do? This PR adds the keyword search implementation for Milvus. Along with the implementation for remote Milvus, the tests require starting a Milvus container locally.
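If a local Milvus instance is not already running, one way to bring up a standalone container is the convenience script from the Milvus installation docs (shown here for reference; any equivalent Docker setup works):

```
# convenience script from the Milvus install docs; starts a standalone Milvus server in Docker
curl -sfL https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh -o standalone_embed.sh
bash standalone_embed.sh start
```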
In order to verify the implementation, run: ``` pytest tests/unit/providers/vector_io/remote/test_milvus.py -v -s --tb=short --disable-warnings --asyncio-mode=auto ``` You can also test the changes using the below script: ``` #!/usr/bin/env python3 import asyncio import os import uuid from typing import List from llama_stack_client import ( Agent, AgentEventLogger, LlamaStackClient, RAGDocument ) class MilvusRAGDemo: def __init__(self, base_url: str = "http://localhost:8321/"): self.client = LlamaStackClient(base_url=base_url) self.vector_db_id = f"milvus_rag_demo_{uuid.uuid4().hex[:8]}" self.model_id = None self.embedding_model_id = None self.embedding_dimension = None def setup_models(self): """Get available models and select appropriate ones for LLM and embeddings.""" models = self.client.models.list() # Select embedding model embedding_models = [m for m in models if m.model_type == "embedding"] if not embedding_models: raise ValueError("No embedding models found") self.embedding_model_id = embedding_models[0].identifier self.embedding_dimension = embedding_models[0].metadata["embedding_dimension"] def register_vector_db(self): print(f"Registering Milvus vector database: {self.vector_db_id}") response = self.client.vector_dbs.register( vector_db_id=self.vector_db_id, embedding_model=self.embedding_model_id, embedding_dimension=self.embedding_dimension, provider_id="milvus-remote", # Use remote Milvus ) print(f"Vector database registered successfully") return response def insert_documents(self): """Insert sample documents into the vector database.""" print("\nInserting sample documents...") # Sample documents about different topics documents = [ RAGDocument( document_id="ai_ml_basics", content=""" Artificial Intelligence (AI) and Machine Learning (ML) are transforming the world. AI refers to the simulation of human intelligence in machines, while ML is a subset of AI that enables computers to learn and improve from experience without being explicitly programmed. Deep learning, a subset of ML, uses neural networks with multiple layers to process complex patterns in data. Key concepts in AI/ML include: - Supervised Learning: Training with labeled data - Unsupervised Learning: Finding patterns in unlabeled data - Reinforcement Learning: Learning through trial and error - Neural Networks: Computing systems inspired by biological brains """, mime_type="text/plain", metadata={"topic": "technology", "category": "ai_ml"}, ), ] # Insert documents with chunking self.client.tool_runtime.rag_tool.insert( documents=documents, vector_db_id=self.vector_db_id, chunk_size_in_tokens=200, # Smaller chunks for better granularity ) print(f"Inserted {len(documents)} documents with chunking") def test_keyword_search(self): """Test keyword-based search using BM25.""" queries = [ "neural networks", "Python frameworks", "data cleaning", ] for query in queries: response = self.client.vector_io.query( vector_db_id=self.vector_db_id, query=query, params={ "mode": "keyword", # Keyword search "max_chunks": 3, "score_threshold": 0.0, } ) for i, (chunk, score) in enumerate(zip(response.chunks, response.scores)): print(f" {i+1}. 
Score: {score:.4f}") print(f" Content: {chunk.content[:100]}...") print(f" Metadata: {chunk.metadata}") def run_demo(self): try: self.setup_models() self.register_vector_db() self.insert_documents() self.test_keyword_search() except Exception as e: print(f"Error during demo: {e}") raise def main(): """Main function to run the demo.""" # Check if Llama Stack server is running demo = MilvusRAGDemo() try: demo.run_demo() except Exception as e: print(f"Demo failed: {e}") if __name__ == "__main__": main() ``` [//]: # (## Documentation) --------- Signed-off-by: Varsha Prasad Narsing --- .../providers/vector_io/remote_milvus.md | 5 +- .../remote/vector_io/milvus/config.py | 13 +- .../remote/vector_io/milvus/milvus.py | 130 +++++++++++- .../providers/vector_io/remote/test_milvus.py | 191 ++++++++++++++++++ 4 files changed, 331 insertions(+), 8 deletions(-) create mode 100644 tests/unit/providers/vector_io/remote/test_milvus.py diff --git a/docs/source/providers/vector_io/remote_milvus.md b/docs/source/providers/vector_io/remote_milvus.md index f3089e615..6734d8315 100644 --- a/docs/source/providers/vector_io/remote_milvus.md +++ b/docs/source/providers/vector_io/remote_milvus.md @@ -114,7 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi | `uri` | `` | No | PydanticUndefined | The URI of the Milvus server | | `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server | | `consistency_level` | `` | No | Strong | The consistency level of the Milvus server | -| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend | | `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. | > **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider. 
@@ -124,6 +124,9 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi ```yaml uri: ${env.MILVUS_ENDPOINT} token: ${env.MILVUS_TOKEN} +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db ``` diff --git a/llama_stack/providers/remote/vector_io/milvus/config.py b/llama_stack/providers/remote/vector_io/milvus/config.py index e3f51b4f4..899d3678d 100644 --- a/llama_stack/providers/remote/vector_io/milvus/config.py +++ b/llama_stack/providers/remote/vector_io/milvus/config.py @@ -8,7 +8,7 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field -from llama_stack.providers.utils.kvstore.config import KVStoreConfig +from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.schema_utils import json_schema_type @@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel): uri: str = Field(description="The URI of the Milvus server") token: str | None = Field(description="The token of the Milvus server") consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong") - kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None) + kvstore: KVStoreConfig = Field(description="Config for KV store backend") # This configuration allows additional fields to be passed through to the underlying Milvus client. # See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. @@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel): @classmethod def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: - return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"} + return { + "uri": "${env.MILVUS_ENDPOINT}", + "token": "${env.MILVUS_TOKEN}", + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="milvus_remote_registry.db", + ), + } diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index d88d954ef..f301942cb 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -12,7 +12,7 @@ import re from typing import Any from numpy.typing import NDArray -from pymilvus import DataType, MilvusClient +from pymilvus import DataType, Function, FunctionType, MilvusClient from llama_stack.apis.files.files import Files from llama_stack.apis.inference import Inference, InterleavedContent @@ -74,12 +74,66 @@ class MilvusIndex(EmbeddingIndex): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) + if not await asyncio.to_thread(self.client.has_collection, self.collection_name): + logger.info(f"Creating new collection {self.collection_name} with nullable sparse field") + # Create schema for vector search + schema = self.client.create_schema() + schema.add_field( + field_name="chunk_id", + datatype=DataType.VARCHAR, + is_primary=True, + max_length=100, + ) + schema.add_field( + field_name="content", + datatype=DataType.VARCHAR, + max_length=65535, + enable_analyzer=True, # Enable text analysis for BM25 + ) + schema.add_field( + field_name="vector", + datatype=DataType.FLOAT_VECTOR, + dim=len(embeddings[0]), + ) + schema.add_field( + field_name="chunk_content", + datatype=DataType.JSON, + ) + # Add sparse vector field for BM25 (required by the 
function) + schema.add_field( + field_name="sparse", + datatype=DataType.SPARSE_FLOAT_VECTOR, + ) + + # Create indexes + index_params = self.client.prepare_index_params() + index_params.add_index( + field_name="vector", + index_type="FLAT", + metric_type="COSINE", + ) + # Add index for sparse field (required by BM25 function) + index_params.add_index( + field_name="sparse", + index_type="SPARSE_INVERTED_INDEX", + metric_type="BM25", + ) + + # Add BM25 function for full-text search + bm25_function = Function( + name="text_bm25_emb", + input_field_names=["content"], + output_field_names=["sparse"], + function_type=FunctionType.BM25, + ) + schema.add_function(bm25_function) + await asyncio.to_thread( self.client.create_collection, self.collection_name, - dimension=len(embeddings[0]), - auto_id=True, + schema=schema, + index_params=index_params, consistency_level=self.consistency_level, ) @@ -88,8 +142,10 @@ class MilvusIndex(EmbeddingIndex): data.append( { "chunk_id": chunk.chunk_id, + "content": chunk.content, "vector": embedding, "chunk_content": chunk.model_dump(), + # sparse field will be handled by BM25 function automatically } ) try: @@ -107,6 +163,7 @@ class MilvusIndex(EmbeddingIndex): self.client.search, collection_name=self.collection_name, data=[embedding], + anns_field="vector", limit=k, output_fields=["*"], search_params={"params": {"radius": score_threshold}}, @@ -121,7 +178,64 @@ class MilvusIndex(EmbeddingIndex): k: int, score_threshold: float, ) -> QueryChunksResponse: - raise NotImplementedError("Keyword search is not supported in Milvus") + """ + Perform BM25-based keyword search using Milvus's built-in full-text search. + """ + try: + # Use Milvus's built-in BM25 search + search_res = await asyncio.to_thread( + self.client.search, + collection_name=self.collection_name, + data=[query_string], # Raw text query + anns_field="sparse", # Use sparse field for BM25 + output_fields=["chunk_content"], # Output the chunk content + limit=k, + search_params={ + "params": { + "drop_ratio_search": 0.2, # Ignore low-importance terms + } + }, + ) + + chunks = [] + scores = [] + for res in search_res[0]: + chunk = Chunk(**res["entity"]["chunk_content"]) + chunks.append(chunk) + scores.append(res["distance"]) # BM25 score from Milvus + + # Filter by score threshold + filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold] + filtered_scores = [score for score in scores if score >= score_threshold] + + return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores) + + except Exception as e: + logger.error(f"Error performing BM25 search: {e}") + # Fallback to simple text search + return await self._fallback_keyword_search(query_string, k, score_threshold) + + async def _fallback_keyword_search( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + """ + Fallback to simple text search when BM25 search is not available. 
+ """ + # Simple text search using content field + search_res = await asyncio.to_thread( + self.client.query, + collection_name=self.collection_name, + filter='content like "%{content}%"', + filter_params={"content": query_string}, + output_fields=["*"], + limit=k, + ) + chunks = [Chunk(**res["chunk_content"]) for res in search_res] + scores = [1.0] * len(chunks) # Simple binary score for text search + return QueryChunksResponse(chunks=chunks, scores=scores) async def query_hybrid( self, @@ -247,6 +361,14 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP if not index: raise ValueError(f"Vector DB {vector_db_id} not found") + if params and params.get("mode") == "keyword": + # Check if this is inline Milvus (Milvus-Lite) + if hasattr(self.config, "db_path"): + raise NotImplementedError( + "Keyword search is not supported in Milvus-Lite. " + "Please use a remote Milvus server for keyword search functionality." + ) + return await index.query_chunks(query, params) async def _save_openai_vector_store_file( diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py new file mode 100644 index 000000000..2f212e374 --- /dev/null +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import pytest_asyncio + +from llama_stack.apis.vector_io import QueryChunksResponse + +# Mock the entire pymilvus module +pymilvus_mock = MagicMock() +pymilvus_mock.DataType = MagicMock() +pymilvus_mock.MilvusClient = MagicMock + +# Apply the mock before importing MilvusIndex +with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}): + from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex + +# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain +# tests which are specific to this class. 
More general (API-level) tests should be placed in +# tests/integration/vector_io/ +# +# How to run this test: +# +# pytest tests/unit/providers/vector_io/test_milvus.py \ +# -v -s --tb=short --disable-warnings --asyncio-mode=auto + +MILVUS_PROVIDER = "milvus" + + +@pytest_asyncio.fixture +async def mock_milvus_client() -> MagicMock: + """Create a mock Milvus client with common method behaviors.""" + client = MagicMock() + + # Mock collection operations + client.has_collection.return_value = False # Initially no collection + client.create_collection.return_value = None + client.drop_collection.return_value = None + + # Mock insert operation + client.insert.return_value = {"insert_count": 10} + + # Mock search operation - return mock results (data should be dict, not JSON string) + client.search.return_value = [ + [ + { + "id": 0, + "distance": 0.1, + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + }, + { + "id": 1, + "distance": 0.2, + "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + }, + ] + ] + + # Mock query operation for keyword search (data should be dict, not JSON string) + client.query.return_value = [ + { + "chunk_id": "chunk1", + "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}, + "score": 0.9, + }, + { + "chunk_id": "chunk2", + "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}, + "score": 0.8, + }, + { + "chunk_id": "chunk3", + "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}}, + "score": 0.7, + }, + ] + + return client + + +@pytest_asyncio.fixture +async def milvus_index(mock_milvus_client): + """Create a MilvusIndex with mocked client.""" + index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection") + yield index + # No real cleanup needed since we're using mocks + + +@pytest.mark.asyncio +async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): + # Setup: collection doesn't exist initially, then exists after creation + mock_milvus_client.has_collection.side_effect = [False, True] + + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Verify collection was created and data was inserted + mock_milvus_client.create_collection.assert_called_once() + mock_milvus_client.insert.assert_called_once() + + # Verify the insert call had the right number of chunks + insert_call = mock_milvus_client.insert.call_args + assert len(insert_call[1]["data"]) == len(sample_chunks) + + +@pytest.mark.asyncio +async def test_query_chunks_vector( + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client +): + # Setup: Add chunks first + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Test vector search + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + mock_milvus_client.search.assert_called_once() + + +@pytest.mark.asyncio +async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Test keyword search + query_string = "Sentence 5" + response = await 
milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + + +@pytest.mark.asyncio +async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): + """Test that when BM25 search fails, the system falls back to simple text search.""" + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Force BM25 search to fail + mock_milvus_client.search.side_effect = Exception("BM25 search not available") + + # Mock simple text search results + mock_milvus_client.query.return_value = [ + { + "chunk_id": "chunk1", + "chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}}, + }, + { + "chunk_id": "chunk2", + "chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}}, + }, + ] + + # Test keyword search that should fall back to simple text search + query_string = "Python" + response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0) + + # Verify response structure + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) > 0, "Fallback search should return results" + + # Verify that simple text search was used (query method called instead of search) + mock_milvus_client.query.assert_called_once() + mock_milvus_client.search.assert_called_once() # Called once but failed + + # Verify the query uses parameterized filter with filter_params + query_call_args = mock_milvus_client.query.call_args + assert "filter" in query_call_args[1], "Query should include filter for text search" + assert "filter_params" in query_call_args[1], "Query should use parameterized filter" + assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term" + + # Verify all returned chunks have score 1.0 (simple binary scoring) + assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring" + + +@pytest.mark.asyncio +async def test_delete_collection(milvus_index, mock_milvus_client): + # Test collection deletion + mock_milvus_client.has_collection.return_value = True + + await milvus_index.delete() + + mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)