mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 10:10:36 +00:00
use FastAPI Query class instead of custom middlware
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
aefbb6f9ea
commit
726bdc414d
7 changed files with 52 additions and 183 deletions
|
|
@ -1,86 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from llama_stack.core.server.query_params_middleware import QueryParamsMiddleware
|
||||
|
||||
|
||||
class TestQueryParamsMiddleware:
|
||||
"""Test cases for the QueryParamsMiddleware."""
|
||||
|
||||
async def test_extracts_query_params_for_vector_store_content(self):
|
||||
"""Test that middleware extracts query params for vector store content endpoints."""
|
||||
middleware = QueryParamsMiddleware(Mock())
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
|
||||
# Mock the URL properly
|
||||
mock_url = Mock()
|
||||
mock_url.path = "/v1/vector_stores/vs_123/files/file_456/content"
|
||||
request.url = mock_url
|
||||
|
||||
request.query_params = {"include_embeddings": "true", "include_metadata": "false"}
|
||||
|
||||
# Create a fresh state object without any attributes
|
||||
class MockState:
|
||||
pass
|
||||
|
||||
request.state = MockState()
|
||||
|
||||
await middleware.dispatch(request, AsyncMock())
|
||||
|
||||
assert hasattr(request.state, "extra_query")
|
||||
assert request.state.extra_query == {"include_embeddings": True, "include_metadata": False}
|
||||
|
||||
async def test_ignores_non_vector_store_endpoints(self):
|
||||
"""Test that middleware ignores non-vector store endpoints."""
|
||||
middleware = QueryParamsMiddleware(Mock())
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
|
||||
# Mock the URL properly
|
||||
mock_url = Mock()
|
||||
mock_url.path = "/v1/inference/chat_completion"
|
||||
request.url = mock_url
|
||||
|
||||
request.query_params = {"include_embeddings": "true"}
|
||||
|
||||
# Create a fresh state object without any attributes
|
||||
class MockState:
|
||||
pass
|
||||
|
||||
request.state = MockState()
|
||||
|
||||
await middleware.dispatch(request, AsyncMock())
|
||||
|
||||
assert not hasattr(request.state, "extra_query")
|
||||
|
||||
async def test_handles_json_parsing(self):
|
||||
"""Test that middleware correctly parses JSON values and handles invalid JSON."""
|
||||
middleware = QueryParamsMiddleware(Mock())
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
|
||||
# Mock the URL properly
|
||||
mock_url = Mock()
|
||||
mock_url.path = "/v1/vector_stores/vs_123/files/file_456/content"
|
||||
request.url = mock_url
|
||||
|
||||
request.query_params = {"config": '{"key": "value"}', "invalid": "not-json{", "number": "42"}
|
||||
|
||||
# Create a fresh state object without any attributes
|
||||
class MockState:
|
||||
pass
|
||||
|
||||
request.state = MockState()
|
||||
|
||||
await middleware.dispatch(request, AsyncMock())
|
||||
|
||||
expected = {"config": {"key": "value"}, "invalid": "not-json{", "number": 42}
|
||||
assert request.state.extra_query == expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue