From aefbb6f9eacbc586f26901e6a736c29aa3aa9eb2 Mon Sep 17 00:00:00 2001
From: Francisco Javier Arceo
Date: Mon, 3 Nov 2025 14:48:52 -0500
Subject: [PATCH] feat: Adding optional embeddings to content

Signed-off-by: Francisco Javier Arceo
---
 client-sdks/stainless/openapi.yml             |  44 ++
 docs/static/llama-stack-spec.yaml             |  44 ++
 docs/static/stainless-llama-stack-spec.yaml   |  44 ++
 src/llama_stack/apis/vector_io/vector_io.py   |  32 +-
 src/llama_stack/core/library_client.py        |   6 +
 src/llama_stack/core/routers/vector_io.py     |  24 +-
 .../core/routing_tables/vector_stores.py      |   9 +-
 .../core/server/query_params_middleware.py    |  49 +++
 src/llama_stack/core/server/server.py         |   8 +
 .../utils/memory/openai_vector_store_mixin.py |  75 ++--
 .../app/logs/vector-stores/page.tsx           | 386 +++++++++++++++---
 .../components/prompts/prompt-editor.test.tsx |   2 +-
 .../vector-store-detail.test.tsx              |  14 +
 .../vector-stores/vector-store-detail.tsx     | 183 ++++++++-
 .../vector-stores/vector-store-editor.tsx     | 235 +++++++++++
 src/llama_stack_ui/lib/contents-api.ts        |  40 +-
 .../vector_io/test_openai_vector_stores.py    |  95 +++++
 tests/unit/core/routers/test_vector_io.py     |  62 +++
 .../server/test_query_params_middleware.py    |  86 ++++
 tests/unit/server/test_sse.py                 |   8 +-
 20 files changed, 1314 insertions(+), 132 deletions(-)
 create mode 100644 src/llama_stack/core/server/query_params_middleware.py
 create mode 100644 src/llama_stack_ui/components/vector-stores/vector-store-editor.tsx
 create mode 100644 tests/unit/server/test_query_params_middleware.py

diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml
index 9f3ef15b5..020f3840e 100644
--- a/client-sdks/stainless/openapi.yml
+++ b/client-sdks/stainless/openapi.yml
@@ -2691,7 +2691,8 @@ paths:
       responses:
         '200':
           description: >-
-            A VectorStoreFileContentResponse representing the file contents.
+            File contents, optionally with embeddings and metadata based on extra_query
+            parameters.
           content:
             application/json:
               schema:
@@ -2726,6 +2727,23 @@
         required: true
         schema:
           type: string
+      - name: extra_query
+        in: query
+        description: >-
+          Optional extra parameters to control response format. Set include_embeddings=true
+          to include embedding vectors. Set include_metadata=true to include chunk
+          metadata.
+        required: false
+        schema:
+          type: object
+          additionalProperties:
+            oneOf:
+            - type: 'null'
+            - type: boolean
+            - type: number
+            - type: string
+            - type: array
+            - type: object
       deprecated: false
   /v1/vector_stores/{vector_store_id}/search:
     post:
@@ -10102,6 +10120,28 @@ components:
         text:
           type: string
           description: The actual text content
+        embedding:
+          type: array
+          items:
+            type: number
+          description: >-
+            Optional embedding vector for this content chunk (when requested via extra_body)
+        chunk_metadata:
+          $ref: '#/components/schemas/ChunkMetadata'
+          description: >-
+            Optional chunk metadata (when requested via extra_body)
+        metadata:
+          type: object
+          additionalProperties:
+            oneOf:
+            - type: 'null'
+            - type: boolean
+            - type: number
+            - type: string
+            - type: array
+            - type: object
+          description: >-
+            Optional user-defined metadata (when requested via extra_body)
       additionalProperties: false
       required:
       - type
diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml
index ce8708b68..0888dd586 100644
--- a/docs/static/llama-stack-spec.yaml
+++ b/docs/static/llama-stack-spec.yaml
@@ -2688,7 +2688,8 @@ paths:
       responses:
         '200':
           description: >-
-            A VectorStoreFileContentResponse representing the file contents.
+            File contents, optionally with embeddings and metadata based on extra_query
+            parameters.
           content:
             application/json:
               schema:
@@ -2723,6 +2724,23 @@
         required: true
         schema:
           type: string
+      - name: extra_query
+        in: query
+        description: >-
+          Optional extra parameters to control response format. Set include_embeddings=true
+          to include embedding vectors. Set include_metadata=true to include chunk
+          metadata.
+        required: false
+        schema:
+          type: object
+          additionalProperties:
+            oneOf:
+            - type: 'null'
+            - type: boolean
+            - type: number
+            - type: string
+            - type: array
+            - type: object
       deprecated: false
   /v1/vector_stores/{vector_store_id}/search:
     post:
@@ -9386,6 +9404,28 @@ components:
         text:
          type: string
          description: The actual text content
+        embedding:
+          type: array
+          items:
+            type: number
+          description: >-
+            Optional embedding vector for this content chunk (when requested via extra_body)
+        chunk_metadata:
+          $ref: '#/components/schemas/ChunkMetadata'
+          description: >-
+            Optional chunk metadata (when requested via extra_body)
+        metadata:
+          type: object
+          additionalProperties:
+            oneOf:
+            - type: 'null'
+            - type: boolean
+            - type: number
+            - type: string
+            - type: array
+            - type: object
+          description: >-
+            Optional user-defined metadata (when requested via extra_body)
       additionalProperties: false
       required:
       - type
diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml
index 9f3ef15b5..020f3840e 100644
--- a/docs/static/stainless-llama-stack-spec.yaml
+++ b/docs/static/stainless-llama-stack-spec.yaml
@@ -2691,7 +2691,8 @@ paths:
       responses:
         '200':
           description: >-
-            A VectorStoreFileContentResponse representing the file contents.
+            File contents, optionally with embeddings and metadata based on extra_query
+            parameters.
           content:
             application/json:
               schema:
@@ -2726,6 +2727,23 @@
         required: true
         schema:
           type: string
+      - name: extra_query
+        in: query
+        description: >-
+          Optional extra parameters to control response format. Set include_embeddings=true
+          to include embedding vectors. Set include_metadata=true to include chunk
+          metadata.
+ required: false + schema: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object deprecated: false /v1/vector_stores/{vector_store_id}/search: post: @@ -10102,6 +10124,28 @@ components: text: type: string description: The actual text content + embedding: + type: array + items: + type: number + description: >- + Optional embedding vector for this content chunk (when requested via extra_body) + chunk_metadata: + $ref: '#/components/schemas/ChunkMetadata' + description: >- + Optional chunk metadata (when requested via extra_body) + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + Optional user-defined metadata (when requested via extra_body) additionalProperties: false required: - type diff --git a/src/llama_stack/apis/vector_io/vector_io.py b/src/llama_stack/apis/vector_io/vector_io.py index 846c6f191..fdc7a36db 100644 --- a/src/llama_stack/apis/vector_io/vector_io.py +++ b/src/llama_stack/apis/vector_io/vector_io.py @@ -10,7 +10,7 @@ # the root directory of this source tree. from typing import Annotated, Any, Literal, Protocol, runtime_checkable -from fastapi import Body +from fastapi import Body, Query from pydantic import BaseModel, Field from llama_stack.apis.common.tracing import telemetry_traceable @@ -224,10 +224,16 @@ class VectorStoreContent(BaseModel): :param type: Content type, currently only "text" is supported :param text: The actual text content + :param embedding: Optional embedding vector for this content chunk (when requested via extra_body) + :param chunk_metadata: Optional chunk metadata (when requested via extra_body) + :param metadata: Optional user-defined metadata (when requested via extra_body) """ type: Literal["text"] text: str + embedding: list[float] | None = None + chunk_metadata: ChunkMetadata | None = None + metadata: dict[str, Any] | None = None @json_schema_type @@ -395,22 +401,6 @@ class VectorStoreListFilesResponse(BaseModel): has_more: bool = False -@json_schema_type -class VectorStoreFileContentResponse(BaseModel): - """Represents the parsed content of a vector store file. - - :param object: The object type, which is always `vector_store.file_content.page` - :param data: Parsed content of the file - :param has_more: Indicates if there are more content pages to fetch - :param next_page: The token for the next page, if any - """ - - object: Literal["vector_store.file_content.page"] = "vector_store.file_content.page" - data: list[VectorStoreContent] - has_more: bool - next_page: str | None = None - - @json_schema_type class VectorStoreFileDeleteResponse(BaseModel): """Response from deleting a vector store file. @@ -732,12 +722,16 @@ class VectorIO(Protocol): self, vector_store_id: str, file_id: str, - ) -> VectorStoreFileContentResponse: + include_embeddings: Annotated[bool | None, Query(default=False)] = False, + include_metadata: Annotated[bool | None, Query(default=False)] = False, + ) -> VectorStoreFileContentsResponse: """Retrieves the contents of a vector store file. :param vector_store_id: The ID of the vector store containing the file to retrieve. :param file_id: The ID of the file to retrieve. - :returns: A VectorStoreFileContentResponse representing the file contents. + :param include_embeddings: Whether to include embedding vectors in the response. 
+ :param include_metadata: Whether to include chunk metadata in the response. + :returns: File contents, optionally with embeddings and metadata based on query parameters. """ ... diff --git a/src/llama_stack/core/library_client.py b/src/llama_stack/core/library_client.py index b8f9f715f..db990368b 100644 --- a/src/llama_stack/core/library_client.py +++ b/src/llama_stack/core/library_client.py @@ -389,6 +389,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls) body |= path_params + # Pass through params that aren't already handled as path params + if options.params: + extra_query_params = {k: v for k, v in options.params.items() if k not in path_params} + if extra_query_params: + body["extra_query"] = extra_query_params + body, field_names = self._handle_file_uploads(options, body) body = self._convert_body(matched_func, body, exclude_params=set(field_names)) diff --git a/src/llama_stack/core/routers/vector_io.py b/src/llama_stack/core/routers/vector_io.py index 9dac461db..15c7bb5d0 100644 --- a/src/llama_stack/core/routers/vector_io.py +++ b/src/llama_stack/core/routers/vector_io.py @@ -24,7 +24,7 @@ from llama_stack.apis.vector_io import ( VectorStoreChunkingStrategyStaticConfig, VectorStoreDeleteResponse, VectorStoreFileBatchObject, - VectorStoreFileContentResponse, + VectorStoreFileContentsResponse, VectorStoreFileDeleteResponse, VectorStoreFileObject, VectorStoreFilesListInBatchResponse, @@ -247,6 +247,13 @@ class VectorIORouter(VectorIO): metadata: dict[str, Any] | None = None, ) -> VectorStoreObject: logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}") + + # Check if provider_id is being changed (not supported) + if metadata and "provider_id" in metadata: + current_store = await self.routing_table.get_object_by_identifier("vector_store", vector_store_id) + if current_store and current_store.provider_id != metadata["provider_id"]: + raise ValueError("provider_id cannot be changed after vector store creation") + provider = await self.routing_table.get_provider_impl(vector_store_id) return await provider.openai_update_vector_store( vector_store_id=vector_store_id, @@ -338,12 +345,19 @@ class VectorIORouter(VectorIO): self, vector_store_id: str, file_id: str, - ) -> VectorStoreFileContentResponse: - logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}") - provider = await self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store_file_contents( + include_embeddings: bool | None = False, + include_metadata: bool | None = False, + ) -> VectorStoreFileContentsResponse: + logger.debug( + f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}, " + f"include_embeddings={include_embeddings}, include_metadata={include_metadata}" + ) + + return await self.routing_table.openai_retrieve_vector_store_file_contents( vector_store_id=vector_store_id, file_id=file_id, + include_embeddings=include_embeddings, + include_metadata=include_metadata, ) async def openai_update_vector_store_file( diff --git a/src/llama_stack/core/routing_tables/vector_stores.py b/src/llama_stack/core/routing_tables/vector_stores.py index f95a4dbe3..649df934e 100644 --- a/src/llama_stack/core/routing_tables/vector_stores.py +++ b/src/llama_stack/core/routing_tables/vector_stores.py @@ -15,7 +15,7 @@ from llama_stack.apis.vector_io.vector_io import ( 
SearchRankingOptions, VectorStoreChunkingStrategy, VectorStoreDeleteResponse, - VectorStoreFileContentResponse, + VectorStoreFileContentsResponse, VectorStoreFileDeleteResponse, VectorStoreFileObject, VectorStoreFileStatus, @@ -195,12 +195,17 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): self, vector_store_id: str, file_id: str, - ) -> VectorStoreFileContentResponse: + include_embeddings: bool | None = False, + include_metadata: bool | None = False, + ) -> VectorStoreFileContentsResponse: await self.assert_action_allowed("read", "vector_store", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) return await provider.openai_retrieve_vector_store_file_contents( vector_store_id=vector_store_id, file_id=file_id, + include_embeddings=include_embeddings, + include_metadata=include_metadata, ) async def openai_update_vector_store_file( diff --git a/src/llama_stack/core/server/query_params_middleware.py b/src/llama_stack/core/server/query_params_middleware.py new file mode 100644 index 000000000..828f401fc --- /dev/null +++ b/src/llama_stack/core/server/query_params_middleware.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import re + +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware + +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="core::middleware") + +# Patterns for endpoints that need query parameter injection +QUERY_PARAM_ENDPOINTS = [ + # /vector_stores/{vector_store_id}/files/{file_id}/content + re.compile(r"/vector_stores/[^/]+/files/[^/]+/content$"), +] + + +class QueryParamsMiddleware(BaseHTTPMiddleware): + """Middleware to inject query parameters into extra_query for specific endpoints""" + + async def dispatch(self, request: Request, call_next): + # Check if this is an endpoint that needs query parameter injection + if request.method == "GET" and any(pattern.search(str(request.url.path)) for pattern in QUERY_PARAM_ENDPOINTS): + # Extract all query parameters and convert to appropriate types + extra_query = {} + query_params = dict(request.query_params) + + # Convert query parameters using JSON parsing for robust type conversion + for key, value in query_params.items(): + try: + # parse as JSON to handles booleans, numbers, strings properly + extra_query[key] = json.loads(value) + except (json.JSONDecodeError, ValueError): + # if parsing fails, keep as string + extra_query[key] = value + + if extra_query: + # Store the extra_query in request state so we can access it later + request.state.extra_query = extra_query + logger.debug(f"QueryParamsMiddleware extracted extra_query: {extra_query}") + + response = await call_next(request) + return response diff --git a/src/llama_stack/core/server/server.py b/src/llama_stack/core/server/server.py index 80505c3f9..7066aa3a0 100644 --- a/src/llama_stack/core/server/server.py +++ b/src/llama_stack/core/server/server.py @@ -46,6 +46,7 @@ from llama_stack.core.request_headers import ( request_provider_data_context, user_from_scope, ) +from llama_stack.core.server.query_params_middleware import QueryParamsMiddleware from llama_stack.core.server.routes import get_all_api_routes from llama_stack.core.stack import ( Stack, @@ -263,6 +264,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: is_streaming = 
is_streaming_request(func.__name__, request, **kwargs) + # Inject extra_query from middleware if available + if hasattr(request.state, "extra_query") and request.state.extra_query: + kwargs["extra_query"] = request.state.extra_query + try: if is_streaming: context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR] @@ -402,6 +407,9 @@ def create_app() -> StackApp: if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): app.add_middleware(ClientVersionMiddleware) + # handle extra_query for specific GET requests + app.add_middleware(QueryParamsMiddleware) + impls = app.stack.impls if config.server.auth: diff --git a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 86e6ea013..ab376be70 100644 --- a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import ( VectorStoreContent, VectorStoreDeleteResponse, VectorStoreFileBatchObject, - VectorStoreFileContentResponse, + VectorStoreFileContentsResponse, VectorStoreFileCounts, VectorStoreFileDeleteResponse, VectorStoreFileLastError, @@ -450,7 +450,7 @@ class OpenAIVectorStoreMixin(ABC): # Now that our vector store is created, attach any files that were provided file_ids = params.file_ids or [] tasks = [self.openai_attach_file_to_vector_store(vector_store_id, file_id) for file_id in file_ids] - await asyncio.gather(*tasks) + await asyncio.gather(*tasks, return_exceptions=True) # Get the updated store info and return it store_info = self.openai_vector_stores[vector_store_id] @@ -704,34 +704,35 @@ class OpenAIVectorStoreMixin(ABC): # Unknown filter type, default to no match raise ValueError(f"Unsupported filter type: {filter_type}") - def _chunk_to_vector_store_content(self, chunk: Chunk) -> list[VectorStoreContent]: - # content is InterleavedContent + def _extract_chunk_fields(self, chunk: Chunk, include_embeddings: bool, include_metadata: bool) -> dict: + """Extract embedding and metadata fields from chunk based on include flags.""" + return { + "embedding": chunk.embedding if include_embeddings else None, + "chunk_metadata": chunk.chunk_metadata if include_metadata else None, + "metadata": chunk.metadata if include_metadata else None, + } + + def _chunk_to_vector_store_content( + self, chunk: Chunk, include_embeddings: bool = False, include_metadata: bool = False + ) -> list[VectorStoreContent]: + fields = self._extract_chunk_fields(chunk, include_embeddings, include_metadata) + if isinstance(chunk.content, str): - content = [ - VectorStoreContent( - type="text", - text=chunk.content, - ) - ] + content_item = VectorStoreContent(type="text", text=chunk.content, **fields) + content = [content_item] elif isinstance(chunk.content, list): # TODO: Add support for other types of content - content = [ - VectorStoreContent( - type="text", - text=item.text, - ) - for item in chunk.content - if item.type == "text" - ] + content = [] + for item in chunk.content: + if item.type == "text": + content_item = VectorStoreContent(type="text", text=item.text, **fields) + content.append(content_item) else: if chunk.content.type != "text": raise ValueError(f"Unsupported content type: {chunk.content.type}") - content = [ - VectorStoreContent( - type="text", - text=chunk.content.text, - ) - ] + + content_item = VectorStoreContent(type="text", text=chunk.content.text, **fields) + content = [content_item] return content async def 
openai_attach_file_to_vector_store( @@ -820,13 +821,12 @@ class OpenAIVectorStoreMixin(ABC): message=str(e), ) - # Create OpenAI vector store file metadata + # Save vector store file to persistent storage AFTER insert_chunks + # so that chunks include the embeddings that were generated file_info = vector_store_file_object.model_dump(exclude={"last_error"}) file_info["filename"] = file_response.filename if file_response else "" - # Save vector store file to persistent storage (provider-specific) dict_chunks = [c.model_dump() for c in chunks] - # This should be updated to include chunk_id await self._save_openai_vector_store_file(vector_store_id, file_id, file_info, dict_chunks) # Update file_ids and file_counts in vector store metadata @@ -921,21 +921,28 @@ class OpenAIVectorStoreMixin(ABC): self, vector_store_id: str, file_id: str, - ) -> VectorStoreFileContentResponse: + include_embeddings: bool | None = False, + include_metadata: bool | None = False, + ) -> VectorStoreFileContentsResponse: """Retrieves the contents of a vector store file.""" if vector_store_id not in self.openai_vector_stores: raise VectorStoreNotFoundError(vector_store_id) + file_info = await self._load_openai_vector_store_file(vector_store_id, file_id) dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id) chunks = [Chunk.model_validate(c) for c in dict_chunks] content = [] for chunk in chunks: - content.extend(self._chunk_to_vector_store_content(chunk)) - return VectorStoreFileContentResponse( - object="vector_store.file_content.page", - data=content, - has_more=False, - next_page=None, + content.extend( + self._chunk_to_vector_store_content( + chunk, include_embeddings=include_embeddings or False, include_metadata=include_metadata or False + ) + ) + return VectorStoreFileContentsResponse( + file_id=file_id, + filename=file_info.get("filename", ""), + attributes=file_info.get("attributes", {}), + content=content, ) async def openai_update_vector_store_file( diff --git a/src/llama_stack_ui/app/logs/vector-stores/page.tsx b/src/llama_stack_ui/app/logs/vector-stores/page.tsx index 72196d496..84680e01a 100644 --- a/src/llama_stack_ui/app/logs/vector-stores/page.tsx +++ b/src/llama_stack_ui/app/logs/vector-stores/page.tsx @@ -8,6 +8,9 @@ import type { import { useRouter } from "next/navigation"; import { usePagination } from "@/hooks/use-pagination"; import { Button } from "@/components/ui/button"; +import { Plus, Trash2, Search, Edit, X } from "lucide-react"; +import { useState } from "react"; +import { Input } from "@/components/ui/input"; import { Table, TableBody, @@ -17,9 +20,21 @@ import { TableRow, } from "@/components/ui/table"; import { Skeleton } from "@/components/ui/skeleton"; +import { useAuthClient } from "@/hooks/use-auth-client"; +import { + VectorStoreEditor, + VectorStoreFormData, +} from "@/components/vector-stores/vector-store-editor"; export default function VectorStoresPage() { const router = useRouter(); + const client = useAuthClient(); + const [deletingStores, setDeletingStores] = useState>(new Set()); + const [searchTerm, setSearchTerm] = useState(""); + const [showVectorStoreModal, setShowVectorStoreModal] = useState(false); + const [editingStore, setEditingStore] = useState(null); + const [modalError, setModalError] = useState(null); + const [showSuccessState, setShowSuccessState] = useState(false); const { data: stores, status, @@ -47,6 +62,142 @@ export default function VectorStoresPage() { } }, [status, hasMore, loadMore]); + // Handle ESC key to close 
modal + React.useEffect(() => { + const handleEscape = (event: KeyboardEvent) => { + if (event.key === "Escape" && showVectorStoreModal) { + handleCancel(); + } + }; + + document.addEventListener("keydown", handleEscape); + return () => document.removeEventListener("keydown", handleEscape); + }, [showVectorStoreModal]); + + const handleDeleteVectorStore = async (storeId: string) => { + if ( + !confirm( + "Are you sure you want to delete this vector store? This action cannot be undone." + ) + ) { + return; + } + + setDeletingStores(prev => new Set([...prev, storeId])); + + try { + await client.vectorStores.delete(storeId); + // Reload the data to reflect the deletion + window.location.reload(); + } catch (err: unknown) { + console.error("Failed to delete vector store:", err); + const errorMessage = err instanceof Error ? err.message : "Unknown error"; + alert(`Failed to delete vector store: ${errorMessage}`); + } finally { + setDeletingStores(prev => { + const newSet = new Set(prev); + newSet.delete(storeId); + return newSet; + }); + } + }; + + const handleSaveVectorStore = async (formData: VectorStoreFormData) => { + try { + setModalError(null); + + if (editingStore) { + // Update existing vector store + const updateParams: { + name?: string; + extra_body?: Record; + } = {}; + + // Only include fields that have changed or are provided + if (formData.name && formData.name !== editingStore.name) { + updateParams.name = formData.name; + } + + // Add all parameters to extra_body (except provider_id which can't be changed) + const extraBody: Record = {}; + if (formData.embedding_model) { + extraBody.embedding_model = formData.embedding_model; + } + if (formData.embedding_dimension) { + extraBody.embedding_dimension = formData.embedding_dimension; + } + + if (Object.keys(extraBody).length > 0) { + updateParams.extra_body = extraBody; + } + + await client.vectorStores.update(editingStore.id, updateParams); + + // Show success state with close button + setShowSuccessState(true); + setModalError( + "✅ Vector store updated successfully! You can close this modal and refresh the page to see changes." + ); + return; + } + + const createParams: { + name?: string; + provider_id?: string; + extra_body?: Record; + } = { + name: formData.name || undefined, + }; + + // Extract provider_id to top-level (like Python client does) + if (formData.provider_id) { + createParams.provider_id = formData.provider_id; + } + + // Add remaining parameters to extra_body + const extraBody: Record = {}; + if (formData.provider_id) { + extraBody.provider_id = formData.provider_id; + } + if (formData.embedding_model) { + extraBody.embedding_model = formData.embedding_model; + } + if (formData.embedding_dimension) { + extraBody.embedding_dimension = formData.embedding_dimension; + } + + if (Object.keys(extraBody).length > 0) { + createParams.extra_body = extraBody; + } + + await client.vectorStores.create(createParams); + + // Show success state with close button + setShowSuccessState(true); + setModalError( + "✅ Vector store created successfully! You can close this modal and refresh the page to see changes." + ); + } catch (err: unknown) { + console.error("Failed to create vector store:", err); + const errorMessage = + err instanceof Error ? 
err.message : "Failed to create vector store"; + setModalError(errorMessage); + } + }; + + const handleEditVectorStore = (store: VectorStore) => { + setEditingStore(store); + setShowVectorStoreModal(true); + setModalError(null); + }; + + const handleCancel = () => { + setShowVectorStoreModal(false); + setEditingStore(null); + setModalError(null); + setShowSuccessState(false); + }; + const renderContent = () => { if (status === "loading") { return ( @@ -66,73 +217,190 @@ export default function VectorStoresPage() { return
<p>No vector stores found.</p>;
    }

-    return (
-      <Table>
- - - - ID - Name - Created - Completed - Cancelled - Failed - In Progress - Total - Usage Bytes - Provider ID - Provider Vector DB ID - - - - {stores.map(store => { - const fileCounts = store.file_counts; - const metadata = store.metadata || {}; - const providerId = metadata.provider_id ?? ""; - const providerDbId = metadata.provider_vector_db_id ?? ""; + // Filter stores based on search term + const filteredStores = stores.filter(store => { + if (!searchTerm) return true; - return ( - router.push(`/logs/vector-stores/${store.id}`)} - className="cursor-pointer hover:bg-muted/50" - > - - - - {store.name} - - {new Date(store.created_at * 1000).toLocaleString()} - - {fileCounts.completed} - {fileCounts.cancelled} - {fileCounts.failed} - {fileCounts.in_progress} - {fileCounts.total} - {store.usage_bytes} - {providerId} - {providerDbId} - - ); - })} - -
+ const searchLower = searchTerm.toLowerCase(); + return ( + store.id.toLowerCase().includes(searchLower) || + (store.name && store.name.toLowerCase().includes(searchLower)) || + (store.metadata?.provider_id && + String(store.metadata.provider_id) + .toLowerCase() + .includes(searchLower)) || + (store.metadata?.provider_vector_db_id && + String(store.metadata.provider_vector_db_id) + .toLowerCase() + .includes(searchLower)) + ); + }); + + return ( +
+ {/* Search Bar */} +
+ + setSearchTerm(e.target.value)} + className="pl-10" + /> +
+ +
+ + + + ID + Name + Created + Completed + Cancelled + Failed + In Progress + Total + Usage Bytes + Provider ID + Provider Vector DB ID + Actions + + + + {filteredStores.map(store => { + const fileCounts = store.file_counts; + const metadata = store.metadata || {}; + const providerId = metadata.provider_id ?? ""; + const providerDbId = metadata.provider_vector_db_id ?? ""; + + return ( + + router.push(`/logs/vector-stores/${store.id}`) + } + className="cursor-pointer hover:bg-muted/50" + > + + + + {store.name} + + {new Date(store.created_at * 1000).toLocaleString()} + + {fileCounts.completed} + {fileCounts.cancelled} + {fileCounts.failed} + {fileCounts.in_progress} + {fileCounts.total} + {store.usage_bytes} + {providerId} + {providerDbId} + +
+ + +
+
+
+ ); + })} +
+
+
); }; return (
-

Vector Stores

+
+

Vector Stores

+ +
{renderContent()} + + {/* Create Vector Store Modal */} + {showVectorStoreModal && ( +
+
+
+

+ {editingStore ? "Edit Vector Store" : "Create New Vector Store"} +

+ +
+
+ +
+
+
+ )}
); } diff --git a/src/llama_stack_ui/components/prompts/prompt-editor.test.tsx b/src/llama_stack_ui/components/prompts/prompt-editor.test.tsx index 458a5f942..70e0e4e66 100644 --- a/src/llama_stack_ui/components/prompts/prompt-editor.test.tsx +++ b/src/llama_stack_ui/components/prompts/prompt-editor.test.tsx @@ -2,7 +2,7 @@ import React from "react"; import { render, screen, fireEvent } from "@testing-library/react"; import "@testing-library/jest-dom"; import { PromptEditor } from "./prompt-editor"; -import type { Prompt, PromptFormData } from "./types"; +import type { Prompt } from "./types"; describe("PromptEditor", () => { const mockOnSave = jest.fn(); diff --git a/src/llama_stack_ui/components/vector-stores/vector-store-detail.test.tsx b/src/llama_stack_ui/components/vector-stores/vector-store-detail.test.tsx index 08f90ac0d..78bec8147 100644 --- a/src/llama_stack_ui/components/vector-stores/vector-store-detail.test.tsx +++ b/src/llama_stack_ui/components/vector-stores/vector-store-detail.test.tsx @@ -12,6 +12,20 @@ jest.mock("next/navigation", () => ({ }), })); +// Mock NextAuth +jest.mock("next-auth/react", () => ({ + useSession: () => ({ + data: { + accessToken: "mock-access-token", + user: { + id: "mock-user-id", + email: "test@example.com", + }, + }, + status: "authenticated", + }), +})); + describe("VectorStoreDetailView", () => { const defaultProps = { store: null, diff --git a/src/llama_stack_ui/components/vector-stores/vector-store-detail.tsx b/src/llama_stack_ui/components/vector-stores/vector-store-detail.tsx index d3d0fa249..f5b6281e7 100644 --- a/src/llama_stack_ui/components/vector-stores/vector-store-detail.tsx +++ b/src/llama_stack_ui/components/vector-stores/vector-store-detail.tsx @@ -1,16 +1,18 @@ "use client"; import { useRouter } from "next/navigation"; +import { useState, useEffect } from "react"; import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores"; import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files"; import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; import { Skeleton } from "@/components/ui/skeleton"; import { Button } from "@/components/ui/button"; +import { useAuthClient } from "@/hooks/use-auth-client"; +import { Edit2, Trash2, X } from "lucide-react"; import { DetailLoadingView, DetailErrorView, DetailNotFoundView, - DetailLayout, PropertiesCard, PropertyItem, } from "@/components/layout/detail-layout"; @@ -23,6 +25,7 @@ import { TableHeader, TableRow, } from "@/components/ui/table"; +import { VectorStoreEditor, VectorStoreFormData } from "./vector-store-editor"; interface VectorStoreDetailViewProps { store: VectorStore | null; @@ -43,21 +46,122 @@ export function VectorStoreDetailView({ errorFiles, id, }: VectorStoreDetailViewProps) { - const title = "Vector Store Details"; const router = useRouter(); + const client = useAuthClient(); + const [isDeleting, setIsDeleting] = useState(false); + const [showEditModal, setShowEditModal] = useState(false); + const [modalError, setModalError] = useState(null); + const [showSuccessState, setShowSuccessState] = useState(false); + + // Handle ESC key to close modal + useEffect(() => { + const handleEscape = (event: KeyboardEvent) => { + if (event.key === "Escape" && showEditModal) { + handleCancel(); + } + }; + + document.addEventListener("keydown", handleEscape); + return () => document.removeEventListener("keydown", handleEscape); + }, [showEditModal]); const handleFileClick = (fileId: string) => { 
router.push(`/logs/vector-stores/${id}/files/${fileId}`); }; + const handleEditVectorStore = () => { + setShowEditModal(true); + setModalError(null); + setShowSuccessState(false); + }; + + const handleCancel = () => { + setShowEditModal(false); + setModalError(null); + setShowSuccessState(false); + }; + + const handleSaveVectorStore = async (formData: VectorStoreFormData) => { + try { + setModalError(null); + + // Update existing vector store (same logic as list page) + const updateParams: { + name?: string; + extra_body?: Record; + } = {}; + + // Only include fields that have changed or are provided + if (formData.name && formData.name !== store?.name) { + updateParams.name = formData.name; + } + + // Add all parameters to extra_body (except provider_id which can't be changed) + const extraBody: Record = {}; + if (formData.embedding_model) { + extraBody.embedding_model = formData.embedding_model; + } + if (formData.embedding_dimension) { + extraBody.embedding_dimension = formData.embedding_dimension; + } + + if (Object.keys(extraBody).length > 0) { + updateParams.extra_body = extraBody; + } + + await client.vectorStores.update(id, updateParams); + + // Show success state + setShowSuccessState(true); + setModalError( + "✅ Vector store updated successfully! You can close this modal and refresh the page to see changes." + ); + } catch (err: unknown) { + console.error("Failed to update vector store:", err); + const errorMessage = + err instanceof Error ? err.message : "Failed to update vector store"; + setModalError(errorMessage); + } + }; + + const handleDeleteVectorStore = async () => { + if ( + !confirm( + "Are you sure you want to delete this vector store? This action cannot be undone." + ) + ) { + return; + } + + setIsDeleting(true); + + try { + await client.vectorStores.delete(id); + // Redirect to the vector stores list after successful deletion + router.push("/logs/vector-stores"); + } catch (err: unknown) { + console.error("Failed to delete vector store:", err); + const errorMessage = err instanceof Error ? err.message : "Unknown error"; + alert(`Failed to delete vector store: ${errorMessage}`); + } finally { + setIsDeleting(false); + } + }; + if (errorStore) { - return ; + return ( + + ); } if (isLoadingStore) { - return ; + return ; } if (!store) { - return ; + return ; } const mainContent = ( @@ -138,6 +242,73 @@ export function VectorStoreDetailView({ ); return ( - + <> +
+

Vector Store Details

+
+ + +
+
+
+
{mainContent}
+
{sidebar}
+
+ + {/* Edit Vector Store Modal */} + {showEditModal && ( +
+
+
+

Edit Vector Store

+ +
+
+ +
+
+
+ )} + ); } diff --git a/src/llama_stack_ui/components/vector-stores/vector-store-editor.tsx b/src/llama_stack_ui/components/vector-stores/vector-store-editor.tsx new file mode 100644 index 000000000..719a2a9fd --- /dev/null +++ b/src/llama_stack_ui/components/vector-stores/vector-store-editor.tsx @@ -0,0 +1,235 @@ +"use client"; + +import { useState, useEffect } from "react"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Card, CardContent } from "@/components/ui/card"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { useAuthClient } from "@/hooks/use-auth-client"; +import type { Model } from "llama-stack-client/resources/models"; + +export interface VectorStoreFormData { + name: string; + embedding_model?: string; + embedding_dimension?: number; + provider_id?: string; +} + +interface VectorStoreEditorProps { + onSave: (formData: VectorStoreFormData) => Promise; + onCancel: () => void; + error?: string | null; + initialData?: VectorStoreFormData; + showSuccessState?: boolean; + isEditing?: boolean; +} + +export function VectorStoreEditor({ + onSave, + onCancel, + error, + initialData, + showSuccessState, + isEditing = false, +}: VectorStoreEditorProps) { + const client = useAuthClient(); + const [formData, setFormData] = useState( + initialData || { + name: "", + embedding_model: "", + embedding_dimension: 768, + provider_id: "", + } + ); + const [loading, setLoading] = useState(false); + const [models, setModels] = useState([]); + const [modelsLoading, setModelsLoading] = useState(true); + const [modelsError, setModelsError] = useState(null); + + const embeddingModels = models.filter( + model => model.custom_metadata?.model_type === "embedding" + ); + + useEffect(() => { + const fetchModels = async () => { + try { + setModelsLoading(true); + setModelsError(null); + const modelList = await client.models.list(); + setModels(modelList); + + // Set default embedding model if available + const embeddingModelsList = modelList.filter(model => { + return model.custom_metadata?.model_type === "embedding"; + }); + if (embeddingModelsList.length > 0 && !formData.embedding_model) { + setFormData(prev => ({ + ...prev, + embedding_model: embeddingModelsList[0].id, + })); + } + } catch (err) { + console.error("Failed to load models:", err); + setModelsError( + err instanceof Error ? err.message : "Failed to load models" + ); + } finally { + setModelsLoading(false); + } + }; + + fetchModels(); + }, [client]); + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + setLoading(true); + + try { + await onSave(formData); + } finally { + setLoading(false); + } + }; + + return ( + + +
+
+ + setFormData({ ...formData, name: e.target.value })} + placeholder="Enter vector store name" + required + /> +
+ +
+ + {modelsLoading ? ( +
+ Loading models... ({models.length} loaded) +
+ ) : modelsError ? ( +
+ Error: {modelsError} +
+ ) : embeddingModels.length === 0 ? ( +
+ No embedding models available ({models.length} total models) +
+ ) : ( + + )} + {formData.embedding_model && ( +

+ Dimension:{" "} + {embeddingModels.find(m => m.id === formData.embedding_model) + ?.custom_metadata?.embedding_dimension || "Unknown"} +

+ )} +
+ +
+ + + setFormData({ + ...formData, + embedding_dimension: parseInt(e.target.value) || 768, + }) + } + placeholder="768" + /> +
+ +
+ + + setFormData({ ...formData, provider_id: e.target.value }) + } + placeholder="e.g., faiss, chroma, sqlite" + disabled={isEditing} + /> + {isEditing && ( +

+ Provider ID cannot be changed after vector store creation +

+ )} +
+ + {error && ( +
+ {error} +
+ )} + +
+ {showSuccessState ? ( + + ) : ( + <> + + + + )} +
+
+
+
+ ); +} diff --git a/src/llama_stack_ui/lib/contents-api.ts b/src/llama_stack_ui/lib/contents-api.ts index f4920f3db..35456faff 100644 --- a/src/llama_stack_ui/lib/contents-api.ts +++ b/src/llama_stack_ui/lib/contents-api.ts @@ -34,9 +34,35 @@ export class ContentsAPI { async getFileContents( vectorStoreId: string, - fileId: string + fileId: string, + includeEmbeddings: boolean = true, + includeMetadata: boolean = true ): Promise { - return this.client.vectorStores.files.content(vectorStoreId, fileId); + try { + // Use query parameters to pass embeddings and metadata flags (OpenAI-compatible pattern) + const extraQuery: Record = {}; + if (includeEmbeddings) { + extraQuery.include_embeddings = true; + } + if (includeMetadata) { + extraQuery.include_metadata = true; + } + + const result = await this.client.vectorStores.files.content( + vectorStoreId, + fileId, + { + query: { + include_embeddings: includeEmbeddings, + include_metadata: includeMetadata, + }, + } + ); + return result; + } catch (error) { + console.error("ContentsAPI.getFileContents error:", error); + throw error; + } } async getContent( @@ -70,11 +96,15 @@ export class ContentsAPI { order?: string; after?: string; before?: string; + includeEmbeddings?: boolean; + includeMetadata?: boolean; } ): Promise { - const fileContents = await this.client.vectorStores.files.content( + const fileContents = await this.getFileContents( vectorStoreId, - fileId + fileId, + options?.includeEmbeddings ?? true, + options?.includeMetadata ?? true ); const contentItems: VectorStoreContentItem[] = []; @@ -82,7 +112,7 @@ export class ContentsAPI { const rawContent = content as Record; // Extract actual fields from the API response - const embedding = rawContent.embedding || undefined; + const embedding = rawContent.embedding as number[] | undefined; const created_timestamp = rawContent.created_timestamp || rawContent.created_at || diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 20f9d2978..153a10e93 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -11,6 +11,7 @@ import pytest from llama_stack_client import BadRequestError from openai import BadRequestError as OpenAIBadRequestError +from llama_stack.apis.files import ExpiresAfter from llama_stack.apis.vector_io import Chunk from llama_stack.core.library_client import LlamaStackAsLibraryClient from llama_stack.log import get_logger @@ -1604,3 +1605,97 @@ def test_openai_vector_store_embedding_config_from_metadata( assert "metadata_config_store" in store_names assert "consistent_config_store" in store_names + + +@vector_provider_wrapper +def test_openai_vector_store_file_contents_with_extra_query( + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id +): + """Test that vector store file contents endpoint supports extra_query parameter.""" + skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) + compat_client = compat_client_with_empty_stores + + # Create a vector store + vector_store = compat_client.vector_stores.create( + name="test_extra_query_store", + extra_body={ + "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, + }, + ) + + # Create and attach a file + test_content = b"This is test content for extra_query validation." 
+ with BytesIO(test_content) as file_buffer: + file_buffer.name = "test_extra_query.txt" + file = compat_client.files.create( + file=file_buffer, + purpose="assistants", + expires_after=ExpiresAfter(anchor="created_at", seconds=86400), + ) + + file_attach_response = compat_client.vector_stores.files.create( + vector_store_id=vector_store.id, + file_id=file.id, + extra_body={"embedding_model": embedding_model_id}, + ) + assert file_attach_response.status == "completed" + + # Wait for processing + time.sleep(2) + + # Test that extra_query parameter is accepted and processed + content_with_extra_query = compat_client.vector_stores.files.content( + vector_store_id=vector_store.id, + file_id=file.id, + extra_query={"include_embeddings": True, "include_metadata": True}, + ) + + # Test without extra_query for comparison + content_without_extra_query = compat_client.vector_stores.files.content( + vector_store_id=vector_store.id, + file_id=file.id, + ) + + # Validate that both calls succeed + assert content_with_extra_query is not None + assert content_without_extra_query is not None + assert len(content_with_extra_query.content) > 0 + assert len(content_without_extra_query.content) > 0 + + # Validate that extra_query parameter is processed correctly + # Both should have the embedding/metadata fields available (may be None based on flags) + first_chunk_with_flags = content_with_extra_query.content[0] + first_chunk_without_flags = content_without_extra_query.content[0] + + # The key validation: extra_query fields are present in the response + # Handle both dict and object responses (different clients may return different formats) + def has_field(obj, field): + if isinstance(obj, dict): + return field in obj + else: + return hasattr(obj, field) + + # Validate that all expected fields are present in both responses + expected_fields = ["embedding", "chunk_metadata", "metadata", "text"] + for field in expected_fields: + assert has_field(first_chunk_with_flags, field), f"Field '{field}' missing from response with extra_query" + assert has_field(first_chunk_without_flags, field), f"Field '{field}' missing from response without extra_query" + + # Validate content is the same + def get_field(obj, field): + if isinstance(obj, dict): + return obj[field] + else: + return getattr(obj, field) + + assert get_field(first_chunk_with_flags, "text") == test_content.decode("utf-8") + assert get_field(first_chunk_without_flags, "text") == test_content.decode("utf-8") + + with_flags_embedding = get_field(first_chunk_with_flags, "embedding") + without_flags_embedding = get_field(first_chunk_without_flags, "embedding") + + # Validate that embeddings are included when requested and excluded when not requested + assert with_flags_embedding is not None, "Embeddings should be included when include_embeddings=True" + assert len(with_flags_embedding) > 0, "Embedding should be a non-empty list" + assert without_flags_embedding is None, "Embeddings should not be included when include_embeddings=False" diff --git a/tests/unit/core/routers/test_vector_io.py b/tests/unit/core/routers/test_vector_io.py index dd3246cb3..f9bd84a37 100644 --- a/tests/unit/core/routers/test_vector_io.py +++ b/tests/unit/core/routers/test_vector_io.py @@ -55,3 +55,65 @@ async def test_create_vector_stores_multiple_providers_missing_provider_id_error with pytest.raises(ValueError, match="Multiple vector_io providers available"): await router.openai_create_vector_store(request) + + +async def test_update_vector_store_provider_id_change_fails(): + """Test 
that updating a vector store with a different provider_id fails with clear error.""" + mock_routing_table = Mock() + + # Mock an existing vector store with provider_id "faiss" + mock_existing_store = Mock() + mock_existing_store.provider_id = "inline::faiss" + mock_existing_store.identifier = "vs_123" + + mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store) + mock_routing_table.get_provider_impl = AsyncMock( + return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123"))) + ) + + router = VectorIORouter(mock_routing_table) + + # Try to update with different provider_id in metadata - this should fail + with pytest.raises(ValueError, match="provider_id cannot be changed after vector store creation"): + await router.openai_update_vector_store( + vector_store_id="vs_123", + name="updated_name", + metadata={"provider_id": "inline::sqlite"}, # Different provider_id + ) + + # Verify the existing store was looked up to check provider_id + mock_routing_table.get_object_by_identifier.assert_called_once_with("vector_store", "vs_123") + + # Provider should not be called since validation failed + mock_routing_table.get_provider_impl.assert_not_called() + + +async def test_update_vector_store_same_provider_id_succeeds(): + """Test that updating a vector store with the same provider_id succeeds.""" + mock_routing_table = Mock() + + # Mock an existing vector store with provider_id "faiss" + mock_existing_store = Mock() + mock_existing_store.provider_id = "inline::faiss" + mock_existing_store.identifier = "vs_123" + + mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store) + mock_routing_table.get_provider_impl = AsyncMock( + return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123"))) + ) + + router = VectorIORouter(mock_routing_table) + + # Update with same provider_id should succeed + await router.openai_update_vector_store( + vector_store_id="vs_123", + name="updated_name", + metadata={"provider_id": "inline::faiss"}, # Same provider_id + ) + + # Verify the provider update method was called + mock_routing_table.get_provider_impl.assert_called_once_with("vs_123") + provider = await mock_routing_table.get_provider_impl("vs_123") + provider.openai_update_vector_store.assert_called_once_with( + vector_store_id="vs_123", name="updated_name", expires_after=None, metadata={"provider_id": "inline::faiss"} + ) diff --git a/tests/unit/server/test_query_params_middleware.py b/tests/unit/server/test_query_params_middleware.py new file mode 100644 index 000000000..30993576b --- /dev/null +++ b/tests/unit/server/test_query_params_middleware.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
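+
+# These tests exercise QueryParamsMiddleware.dispatch directly with mocked
+# Request objects: on matching GET endpoints every query value is JSON-decoded
+# into request.state.extra_query, while non-matching paths leave request.state
+# untouched.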
+ +from unittest.mock import AsyncMock, Mock + +from fastapi import Request + +from llama_stack.core.server.query_params_middleware import QueryParamsMiddleware + + +class TestQueryParamsMiddleware: + """Test cases for the QueryParamsMiddleware.""" + + async def test_extracts_query_params_for_vector_store_content(self): + """Test that middleware extracts query params for vector store content endpoints.""" + middleware = QueryParamsMiddleware(Mock()) + request = Mock(spec=Request) + request.method = "GET" + + # Mock the URL properly + mock_url = Mock() + mock_url.path = "/v1/vector_stores/vs_123/files/file_456/content" + request.url = mock_url + + request.query_params = {"include_embeddings": "true", "include_metadata": "false"} + + # Create a fresh state object without any attributes + class MockState: + pass + + request.state = MockState() + + await middleware.dispatch(request, AsyncMock()) + + assert hasattr(request.state, "extra_query") + assert request.state.extra_query == {"include_embeddings": True, "include_metadata": False} + + async def test_ignores_non_vector_store_endpoints(self): + """Test that middleware ignores non-vector store endpoints.""" + middleware = QueryParamsMiddleware(Mock()) + request = Mock(spec=Request) + request.method = "GET" + + # Mock the URL properly + mock_url = Mock() + mock_url.path = "/v1/inference/chat_completion" + request.url = mock_url + + request.query_params = {"include_embeddings": "true"} + + # Create a fresh state object without any attributes + class MockState: + pass + + request.state = MockState() + + await middleware.dispatch(request, AsyncMock()) + + assert not hasattr(request.state, "extra_query") + + async def test_handles_json_parsing(self): + """Test that middleware correctly parses JSON values and handles invalid JSON.""" + middleware = QueryParamsMiddleware(Mock()) + request = Mock(spec=Request) + request.method = "GET" + + # Mock the URL properly + mock_url = Mock() + mock_url.path = "/v1/vector_stores/vs_123/files/file_456/content" + request.url = mock_url + + request.query_params = {"config": '{"key": "value"}', "invalid": "not-json{", "number": "42"} + + # Create a fresh state object without any attributes + class MockState: + pass + + request.state = MockState() + + await middleware.dispatch(request, AsyncMock()) + + expected = {"config": {"key": "value"}, "invalid": "not-json{", "number": 42} + assert request.state.extra_query == expected diff --git a/tests/unit/server/test_sse.py b/tests/unit/server/test_sse.py index f36c8c181..0303a6ded 100644 --- a/tests/unit/server/test_sse.py +++ b/tests/unit/server/test_sse.py @@ -104,12 +104,18 @@ async def test_paginated_response_url_setting(): route_handler = create_dynamic_typed_route(mock_api_method, "get", "/test/route") - # Mock minimal request + # Mock minimal request with proper state object request = MagicMock() request.scope = {"user_attributes": {}, "principal": ""} request.headers = {} request.body = AsyncMock(return_value=b"") + # Create a simple state object without auto-generating attributes + class MockState: + pass + + request.state = MockState() + result = await route_handler(request) assert isinstance(result, PaginatedResponse)
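
A minimal end-to-end sketch of the new flags from the client side, mirroring
test_openai_vector_store_file_contents_with_extra_query above. The base URL and
the vector store/file IDs are placeholders; everything else comes from this
patch (the extra_query flags travel as GET query parameters, are parsed by
QueryParamsMiddleware into request.state.extra_query, and surface as the
embedding/chunk_metadata/metadata fields on each content chunk):

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder URL

    # Default behavior: plain text chunks; embedding and chunk_metadata stay None.
    plain = client.vector_stores.files.content(
        vector_store_id="vs_123",  # placeholder ID
        file_id="file_456",        # placeholder ID
    )

    # Opt in to embeddings and chunk metadata via extra_query.
    enriched = client.vector_stores.files.content(
        vector_store_id="vs_123",
        file_id="file_456",
        extra_query={"include_embeddings": True, "include_metadata": True},
    )

    for chunk in enriched.content:
        print(chunk.text[:40], chunk.embedding is not None, chunk.chunk_metadata)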