mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
86 lines
2.8 KiB
Python
86 lines
2.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from unittest.mock import AsyncMock, Mock
|
|
|
|
from fastapi import Request
|
|
|
|
from llama_stack.core.server.query_params_middleware import QueryParamsMiddleware
|
|
|
|
|
|
class TestQueryParamsMiddleware:
    """Test cases for the QueryParamsMiddleware."""

    # A vector store file-content endpoint; the middleware is expected to act on it.
    VECTOR_STORE_CONTENT_PATH = "/v1/vector_stores/vs_123/files/file_456/content"

    @staticmethod
    def _make_request(path: str, query_params: dict[str, str]) -> Mock:
        """Build a mocked GET ``Request`` for ``path`` carrying ``query_params``.

        ``request.state`` is a fresh object with no attributes, so tests can tell
        whether the middleware attached ``extra_query`` to it or not.
        """

        # Plain class (not a Mock) so that hasattr() only sees attributes that
        # were actually set on it.
        class MockState:
            pass

        mock_url = Mock()
        mock_url.path = path

        request = Mock(spec=Request)
        request.method = "GET"
        request.url = mock_url
        request.query_params = query_params
        request.state = MockState()
        return request

    async def test_extracts_query_params_for_vector_store_content(self):
        """Test that middleware extracts query params for vector store content endpoints."""
        middleware = QueryParamsMiddleware(Mock())
        request = self._make_request(
            self.VECTOR_STORE_CONTENT_PATH,
            {"include_embeddings": "true", "include_metadata": "false"},
        )

        await middleware.dispatch(request, AsyncMock())

        assert hasattr(request.state, "extra_query")
        assert request.state.extra_query == {"include_embeddings": True, "include_metadata": False}

    async def test_ignores_non_vector_store_endpoints(self):
        """Test that middleware ignores non-vector store endpoints."""
        middleware = QueryParamsMiddleware(Mock())
        request = self._make_request(
            "/v1/inference/chat_completion",
            {"include_embeddings": "true"},
        )

        await middleware.dispatch(request, AsyncMock())

        assert not hasattr(request.state, "extra_query")

    async def test_handles_json_parsing(self):
        """Test that middleware correctly parses JSON values and handles invalid JSON."""
        middleware = QueryParamsMiddleware(Mock())
        request = self._make_request(
            self.VECTOR_STORE_CONTENT_PATH,
            {"config": '{"key": "value"}', "invalid": "not-json{", "number": "42"},
        )

        await middleware.dispatch(request, AsyncMock())

        # Valid JSON values are decoded; values that fail to parse stay raw strings.
        expected = {"config": {"key": "value"}, "invalid": "not-json{", "number": 42}
        assert request.state.extra_query == expected
|