use FastAPI Query class instead of custom middlware

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-11-07 15:09:11 -05:00
parent aefbb6f9ea
commit 726bdc414d
7 changed files with 52 additions and 183 deletions

View file

@ -1,86 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, Mock
from fastapi import Request
from llama_stack.core.server.query_params_middleware import QueryParamsMiddleware
class TestQueryParamsMiddleware:
"""Test cases for the QueryParamsMiddleware."""
async def test_extracts_query_params_for_vector_store_content(self):
"""Test that middleware extracts query params for vector store content endpoints."""
middleware = QueryParamsMiddleware(Mock())
request = Mock(spec=Request)
request.method = "GET"
# Mock the URL properly
mock_url = Mock()
mock_url.path = "/v1/vector_stores/vs_123/files/file_456/content"
request.url = mock_url
request.query_params = {"include_embeddings": "true", "include_metadata": "false"}
# Create a fresh state object without any attributes
class MockState:
pass
request.state = MockState()
await middleware.dispatch(request, AsyncMock())
assert hasattr(request.state, "extra_query")
assert request.state.extra_query == {"include_embeddings": True, "include_metadata": False}
async def test_ignores_non_vector_store_endpoints(self):
"""Test that middleware ignores non-vector store endpoints."""
middleware = QueryParamsMiddleware(Mock())
request = Mock(spec=Request)
request.method = "GET"
# Mock the URL properly
mock_url = Mock()
mock_url.path = "/v1/inference/chat_completion"
request.url = mock_url
request.query_params = {"include_embeddings": "true"}
# Create a fresh state object without any attributes
class MockState:
pass
request.state = MockState()
await middleware.dispatch(request, AsyncMock())
assert not hasattr(request.state, "extra_query")
async def test_handles_json_parsing(self):
"""Test that middleware correctly parses JSON values and handles invalid JSON."""
middleware = QueryParamsMiddleware(Mock())
request = Mock(spec=Request)
request.method = "GET"
# Mock the URL properly
mock_url = Mock()
mock_url.path = "/v1/vector_stores/vs_123/files/file_456/content"
request.url = mock_url
request.query_params = {"config": '{"key": "value"}', "invalid": "not-json{", "number": "42"}
# Create a fresh state object without any attributes
class MockState:
pass
request.state = MockState()
await middleware.dispatch(request, AsyncMock())
expected = {"config": {"key": "value"}, "invalid": "not-json{", "number": 42}
assert request.state.extra_query == expected