mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
use FastAPI Query class instead of custom middleware
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
aefbb6f9ea
commit
726bdc414d
7 changed files with 52 additions and 183 deletions
|
|
@ -2691,10 +2691,14 @@ paths:
|
|||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
<<<<<<< HEAD
|
||||
<<<<<<< HEAD
|
||||
A VectorStoreFileContentResponse representing the file contents.
|
||||
=======
|
||||
File contents, optionally with embeddings and metadata based on extra_query
|
||||
=======
|
||||
File contents, optionally with embeddings and metadata based on query
|
||||
>>>>>>> c192529c (use FastAPI Query class instead of custom middlware)
|
||||
parameters.
|
||||
>>>>>>> 639f0daa (feat: Adding optional embeddings to content)
|
||||
content:
|
||||
|
|
@ -2731,23 +2735,20 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: extra_query
|
||||
- name: include_embeddings
|
||||
in: query
|
||||
description: >-
|
||||
Optional extra parameters to control response format. Set include_embeddings=true
|
||||
to include embedding vectors. Set include_metadata=true to include chunk
|
||||
metadata.
|
||||
Whether to include embedding vectors in the response.
|
||||
required: false
|
||||
schema:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
$ref: '#/components/schemas/bool'
|
||||
- name: include_metadata
|
||||
in: query
|
||||
description: >-
|
||||
Whether to include chunk metadata in the response.
|
||||
required: false
|
||||
schema:
|
||||
$ref: '#/components/schemas/bool'
|
||||
deprecated: false
|
||||
/v1/vector_stores/{vector_store_id}/search:
|
||||
post:
|
||||
|
|
@ -10113,6 +10114,8 @@ components:
|
|||
title: VectorStoreFileDeleteResponse
|
||||
description: >-
|
||||
Response from deleting a vector store file.
|
||||
bool:
|
||||
type: boolean
|
||||
VectorStoreContent:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
|||
29
docs/static/llama-stack-spec.yaml
vendored
29
docs/static/llama-stack-spec.yaml
vendored
|
|
@ -2688,10 +2688,14 @@ paths:
|
|||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
<<<<<<< HEAD
|
||||
<<<<<<< HEAD
|
||||
A VectorStoreFileContentResponse representing the file contents.
|
||||
=======
|
||||
File contents, optionally with embeddings and metadata based on extra_query
|
||||
=======
|
||||
File contents, optionally with embeddings and metadata based on query
|
||||
>>>>>>> c192529c (use FastAPI Query class instead of custom middlware)
|
||||
parameters.
|
||||
>>>>>>> 639f0daa (feat: Adding optional embeddings to content)
|
||||
content:
|
||||
|
|
@ -2728,23 +2732,20 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: extra_query
|
||||
- name: include_embeddings
|
||||
in: query
|
||||
description: >-
|
||||
Optional extra parameters to control response format. Set include_embeddings=true
|
||||
to include embedding vectors. Set include_metadata=true to include chunk
|
||||
metadata.
|
||||
Whether to include embedding vectors in the response.
|
||||
required: false
|
||||
schema:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
$ref: '#/components/schemas/bool'
|
||||
- name: include_metadata
|
||||
in: query
|
||||
description: >-
|
||||
Whether to include chunk metadata in the response.
|
||||
required: false
|
||||
schema:
|
||||
$ref: '#/components/schemas/bool'
|
||||
deprecated: false
|
||||
/v1/vector_stores/{vector_store_id}/search:
|
||||
post:
|
||||
|
|
@ -9397,6 +9398,8 @@ components:
|
|||
title: VectorStoreFileDeleteResponse
|
||||
description: >-
|
||||
Response from deleting a vector store file.
|
||||
bool:
|
||||
type: boolean
|
||||
VectorStoreContent:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
|||
29
docs/static/stainless-llama-stack-spec.yaml
vendored
29
docs/static/stainless-llama-stack-spec.yaml
vendored
|
|
@ -2691,10 +2691,14 @@ paths:
|
|||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
<<<<<<< HEAD
|
||||
<<<<<<< HEAD
|
||||
A VectorStoreFileContentResponse representing the file contents.
|
||||
=======
|
||||
File contents, optionally with embeddings and metadata based on extra_query
|
||||
=======
|
||||
File contents, optionally with embeddings and metadata based on query
|
||||
>>>>>>> c192529c (use FastAPI Query class instead of custom middlware)
|
||||
parameters.
|
||||
>>>>>>> 639f0daa (feat: Adding optional embeddings to content)
|
||||
content:
|
||||
|
|
@ -2731,23 +2735,20 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: extra_query
|
||||
- name: include_embeddings
|
||||
in: query
|
||||
description: >-
|
||||
Optional extra parameters to control response format. Set include_embeddings=true
|
||||
to include embedding vectors. Set include_metadata=true to include chunk
|
||||
metadata.
|
||||
Whether to include embedding vectors in the response.
|
||||
required: false
|
||||
schema:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
$ref: '#/components/schemas/bool'
|
||||
- name: include_metadata
|
||||
in: query
|
||||
description: >-
|
||||
Whether to include chunk metadata in the response.
|
||||
required: false
|
||||
schema:
|
||||
$ref: '#/components/schemas/bool'
|
||||
deprecated: false
|
||||
/v1/vector_stores/{vector_store_id}/search:
|
||||
post:
|
||||
|
|
@ -10113,6 +10114,8 @@ components:
|
|||
title: VectorStoreFileDeleteResponse
|
||||
description: >-
|
||||
Response from deleting a vector store file.
|
||||
bool:
|
||||
type: boolean
|
||||
VectorStoreContent:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
|||
|
|
@ -1,49 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="core::middleware")
|
||||
|
||||
# Patterns for endpoints that need query parameter injection
|
||||
QUERY_PARAM_ENDPOINTS = [
|
||||
# /vector_stores/{vector_store_id}/files/{file_id}/content
|
||||
re.compile(r"/vector_stores/[^/]+/files/[^/]+/content$"),
|
||||
]
|
||||
|
||||
|
||||
class QueryParamsMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to inject query parameters into extra_query for specific endpoints"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Check if this is an endpoint that needs query parameter injection
|
||||
if request.method == "GET" and any(pattern.search(str(request.url.path)) for pattern in QUERY_PARAM_ENDPOINTS):
|
||||
# Extract all query parameters and convert to appropriate types
|
||||
extra_query = {}
|
||||
query_params = dict(request.query_params)
|
||||
|
||||
# Convert query parameters using JSON parsing for robust type conversion
|
||||
for key, value in query_params.items():
|
||||
try:
|
||||
# parse as JSON to handle booleans, numbers, and strings properly
|
||||
extra_query[key] = json.loads(value)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# if parsing fails, keep as string
|
||||
extra_query[key] = value
|
||||
|
||||
if extra_query:
|
||||
# Store the extra_query in request state so we can access it later
|
||||
request.state.extra_query = extra_query
|
||||
logger.debug(f"QueryParamsMiddleware extracted extra_query: {extra_query}")
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
|
@ -46,7 +46,6 @@ from llama_stack.core.request_headers import (
|
|||
request_provider_data_context,
|
||||
user_from_scope,
|
||||
)
|
||||
from llama_stack.core.server.query_params_middleware import QueryParamsMiddleware
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
from llama_stack.core.stack import (
|
||||
Stack,
|
||||
|
|
@ -264,10 +263,6 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
||||
# Inject extra_query from middleware if available
|
||||
if hasattr(request.state, "extra_query") and request.state.extra_query:
|
||||
kwargs["extra_query"] = request.state.extra_query
|
||||
|
||||
try:
|
||||
if is_streaming:
|
||||
context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
||||
|
|
@ -407,9 +402,6 @@ def create_app() -> StackApp:
|
|||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
# handle extra_query for specific GET requests
|
||||
app.add_middleware(QueryParamsMiddleware)
|
||||
|
||||
impls = app.stack.impls
|
||||
|
||||
if config.server.auth:
|
||||
|
|
|
|||
|
|
@ -450,7 +450,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
# Now that our vector store is created, attach any files that were provided
|
||||
file_ids = params.file_ids or []
|
||||
tasks = [self.openai_attach_file_to_vector_store(vector_store_id, file_id) for file_id in file_ids]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Get the updated store info and return it
|
||||
store_info = self.openai_vector_stores[vector_store_id]
|
||||
|
|
@ -928,6 +928,9 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
if vector_store_id not in self.openai_vector_stores:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
# Parameters are already provided directly
|
||||
# include_embeddings and include_metadata are now function parameters
|
||||
|
||||
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
|
||||
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
|
||||
chunks = [Chunk.model_validate(c) for c in dict_chunks]
|
||||
|
|
|
|||
|
|
@ -1,86 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from llama_stack.core.server.query_params_middleware import QueryParamsMiddleware
|
||||
|
||||
|
||||
class TestQueryParamsMiddleware:
|
||||
"""Test cases for the QueryParamsMiddleware."""
|
||||
|
||||
async def test_extracts_query_params_for_vector_store_content(self):
|
||||
"""Test that middleware extracts query params for vector store content endpoints."""
|
||||
middleware = QueryParamsMiddleware(Mock())
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
|
||||
# Mock the URL properly
|
||||
mock_url = Mock()
|
||||
mock_url.path = "/v1/vector_stores/vs_123/files/file_456/content"
|
||||
request.url = mock_url
|
||||
|
||||
request.query_params = {"include_embeddings": "true", "include_metadata": "false"}
|
||||
|
||||
# Create a fresh state object without any attributes
|
||||
class MockState:
|
||||
pass
|
||||
|
||||
request.state = MockState()
|
||||
|
||||
await middleware.dispatch(request, AsyncMock())
|
||||
|
||||
assert hasattr(request.state, "extra_query")
|
||||
assert request.state.extra_query == {"include_embeddings": True, "include_metadata": False}
|
||||
|
||||
async def test_ignores_non_vector_store_endpoints(self):
|
||||
"""Test that middleware ignores non-vector store endpoints."""
|
||||
middleware = QueryParamsMiddleware(Mock())
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
|
||||
# Mock the URL properly
|
||||
mock_url = Mock()
|
||||
mock_url.path = "/v1/inference/chat_completion"
|
||||
request.url = mock_url
|
||||
|
||||
request.query_params = {"include_embeddings": "true"}
|
||||
|
||||
# Create a fresh state object without any attributes
|
||||
class MockState:
|
||||
pass
|
||||
|
||||
request.state = MockState()
|
||||
|
||||
await middleware.dispatch(request, AsyncMock())
|
||||
|
||||
assert not hasattr(request.state, "extra_query")
|
||||
|
||||
async def test_handles_json_parsing(self):
|
||||
"""Test that middleware correctly parses JSON values and handles invalid JSON."""
|
||||
middleware = QueryParamsMiddleware(Mock())
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
|
||||
# Mock the URL properly
|
||||
mock_url = Mock()
|
||||
mock_url.path = "/v1/vector_stores/vs_123/files/file_456/content"
|
||||
request.url = mock_url
|
||||
|
||||
request.query_params = {"config": '{"key": "value"}', "invalid": "not-json{", "number": "42"}
|
||||
|
||||
# Create a fresh state object without any attributes
|
||||
class MockState:
|
||||
pass
|
||||
|
||||
request.state = MockState()
|
||||
|
||||
await middleware.dispatch(request, AsyncMock())
|
||||
|
||||
expected = {"config": {"key": "value"}, "invalid": "not-json{", "number": 42}
|
||||
assert request.state.extra_query == expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue