feat: Adding optional embeddings to content

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-11-03 14:48:52 -05:00
parent 97ccfb5e62
commit aefbb6f9ea
20 changed files with 1314 additions and 132 deletions

View file

@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, Mock
from fastapi import Request
from llama_stack.core.server.query_params_middleware import QueryParamsMiddleware
class TestQueryParamsMiddleware:
"""Test cases for the QueryParamsMiddleware."""
async def test_extracts_query_params_for_vector_store_content(self):
"""Test that middleware extracts query params for vector store content endpoints."""
middleware = QueryParamsMiddleware(Mock())
request = Mock(spec=Request)
request.method = "GET"
# Mock the URL properly
mock_url = Mock()
mock_url.path = "/v1/vector_stores/vs_123/files/file_456/content"
request.url = mock_url
request.query_params = {"include_embeddings": "true", "include_metadata": "false"}
# Create a fresh state object without any attributes
class MockState:
pass
request.state = MockState()
await middleware.dispatch(request, AsyncMock())
assert hasattr(request.state, "extra_query")
assert request.state.extra_query == {"include_embeddings": True, "include_metadata": False}
async def test_ignores_non_vector_store_endpoints(self):
"""Test that middleware ignores non-vector store endpoints."""
middleware = QueryParamsMiddleware(Mock())
request = Mock(spec=Request)
request.method = "GET"
# Mock the URL properly
mock_url = Mock()
mock_url.path = "/v1/inference/chat_completion"
request.url = mock_url
request.query_params = {"include_embeddings": "true"}
# Create a fresh state object without any attributes
class MockState:
pass
request.state = MockState()
await middleware.dispatch(request, AsyncMock())
assert not hasattr(request.state, "extra_query")
async def test_handles_json_parsing(self):
"""Test that middleware correctly parses JSON values and handles invalid JSON."""
middleware = QueryParamsMiddleware(Mock())
request = Mock(spec=Request)
request.method = "GET"
# Mock the URL properly
mock_url = Mock()
mock_url.path = "/v1/vector_stores/vs_123/files/file_456/content"
request.url = mock_url
request.query_params = {"config": '{"key": "value"}', "invalid": "not-json{", "number": "42"}
# Create a fresh state object without any attributes
class MockState:
pass
request.state = MockState()
await middleware.dispatch(request, AsyncMock())
expected = {"config": {"key": "value"}, "invalid": "not-json{", "number": 42}
assert request.state.extra_query == expected

View file

@ -104,12 +104,18 @@ async def test_paginated_response_url_setting():
route_handler = create_dynamic_typed_route(mock_api_method, "get", "/test/route")
# Mock minimal request
# Mock minimal request with proper state object
request = MagicMock()
request.scope = {"user_attributes": {}, "principal": ""}
request.headers = {}
request.body = AsyncMock(return_value=b"")
# Create a simple state object without auto-generating attributes
class MockState:
pass
request.state = MockState()
result = await route_handler(request)
assert isinstance(result, PaginatedResponse)