mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
use FastAPI Query class instead of custom middlware
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
aefbb6f9ea
commit
726bdc414d
7 changed files with 52 additions and 183 deletions
|
|
@ -2691,10 +2691,14 @@ paths:
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: >-
|
description: >-
|
||||||
|
<<<<<<< HEAD
|
||||||
<<<<<<< HEAD
|
<<<<<<< HEAD
|
||||||
A VectorStoreFileContentResponse representing the file contents.
|
A VectorStoreFileContentResponse representing the file contents.
|
||||||
=======
|
=======
|
||||||
File contents, optionally with embeddings and metadata based on extra_query
|
File contents, optionally with embeddings and metadata based on extra_query
|
||||||
|
=======
|
||||||
|
File contents, optionally with embeddings and metadata based on query
|
||||||
|
>>>>>>> c192529c (use FastAPI Query class instead of custom middlware)
|
||||||
parameters.
|
parameters.
|
||||||
>>>>>>> 639f0daa (feat: Adding optional embeddings to content)
|
>>>>>>> 639f0daa (feat: Adding optional embeddings to content)
|
||||||
content:
|
content:
|
||||||
|
|
@ -2731,23 +2735,20 @@ paths:
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
- name: extra_query
|
- name: include_embeddings
|
||||||
in: query
|
in: query
|
||||||
description: >-
|
description: >-
|
||||||
Optional extra parameters to control response format. Set include_embeddings=true
|
Whether to include embedding vectors in the response.
|
||||||
to include embedding vectors. Set include_metadata=true to include chunk
|
|
||||||
metadata.
|
|
||||||
required: false
|
required: false
|
||||||
schema:
|
schema:
|
||||||
type: object
|
$ref: '#/components/schemas/bool'
|
||||||
additionalProperties:
|
- name: include_metadata
|
||||||
oneOf:
|
in: query
|
||||||
- type: 'null'
|
description: >-
|
||||||
- type: boolean
|
Whether to include chunk metadata in the response.
|
||||||
- type: number
|
required: false
|
||||||
- type: string
|
schema:
|
||||||
- type: array
|
$ref: '#/components/schemas/bool'
|
||||||
- type: object
|
|
||||||
deprecated: false
|
deprecated: false
|
||||||
/v1/vector_stores/{vector_store_id}/search:
|
/v1/vector_stores/{vector_store_id}/search:
|
||||||
post:
|
post:
|
||||||
|
|
@ -10113,6 +10114,8 @@ components:
|
||||||
title: VectorStoreFileDeleteResponse
|
title: VectorStoreFileDeleteResponse
|
||||||
description: >-
|
description: >-
|
||||||
Response from deleting a vector store file.
|
Response from deleting a vector store file.
|
||||||
|
bool:
|
||||||
|
type: boolean
|
||||||
VectorStoreContent:
|
VectorStoreContent:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
|
||||||
29
docs/static/llama-stack-spec.yaml
vendored
29
docs/static/llama-stack-spec.yaml
vendored
|
|
@ -2688,10 +2688,14 @@ paths:
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: >-
|
description: >-
|
||||||
|
<<<<<<< HEAD
|
||||||
<<<<<<< HEAD
|
<<<<<<< HEAD
|
||||||
A VectorStoreFileContentResponse representing the file contents.
|
A VectorStoreFileContentResponse representing the file contents.
|
||||||
=======
|
=======
|
||||||
File contents, optionally with embeddings and metadata based on extra_query
|
File contents, optionally with embeddings and metadata based on extra_query
|
||||||
|
=======
|
||||||
|
File contents, optionally with embeddings and metadata based on query
|
||||||
|
>>>>>>> c192529c (use FastAPI Query class instead of custom middlware)
|
||||||
parameters.
|
parameters.
|
||||||
>>>>>>> 639f0daa (feat: Adding optional embeddings to content)
|
>>>>>>> 639f0daa (feat: Adding optional embeddings to content)
|
||||||
content:
|
content:
|
||||||
|
|
@ -2728,23 +2732,20 @@ paths:
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
- name: extra_query
|
- name: include_embeddings
|
||||||
in: query
|
in: query
|
||||||
description: >-
|
description: >-
|
||||||
Optional extra parameters to control response format. Set include_embeddings=true
|
Whether to include embedding vectors in the response.
|
||||||
to include embedding vectors. Set include_metadata=true to include chunk
|
|
||||||
metadata.
|
|
||||||
required: false
|
required: false
|
||||||
schema:
|
schema:
|
||||||
type: object
|
$ref: '#/components/schemas/bool'
|
||||||
additionalProperties:
|
- name: include_metadata
|
||||||
oneOf:
|
in: query
|
||||||
- type: 'null'
|
description: >-
|
||||||
- type: boolean
|
Whether to include chunk metadata in the response.
|
||||||
- type: number
|
required: false
|
||||||
- type: string
|
schema:
|
||||||
- type: array
|
$ref: '#/components/schemas/bool'
|
||||||
- type: object
|
|
||||||
deprecated: false
|
deprecated: false
|
||||||
/v1/vector_stores/{vector_store_id}/search:
|
/v1/vector_stores/{vector_store_id}/search:
|
||||||
post:
|
post:
|
||||||
|
|
@ -9397,6 +9398,8 @@ components:
|
||||||
title: VectorStoreFileDeleteResponse
|
title: VectorStoreFileDeleteResponse
|
||||||
description: >-
|
description: >-
|
||||||
Response from deleting a vector store file.
|
Response from deleting a vector store file.
|
||||||
|
bool:
|
||||||
|
type: boolean
|
||||||
VectorStoreContent:
|
VectorStoreContent:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
|
||||||
29
docs/static/stainless-llama-stack-spec.yaml
vendored
29
docs/static/stainless-llama-stack-spec.yaml
vendored
|
|
@ -2691,10 +2691,14 @@ paths:
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: >-
|
description: >-
|
||||||
|
<<<<<<< HEAD
|
||||||
<<<<<<< HEAD
|
<<<<<<< HEAD
|
||||||
A VectorStoreFileContentResponse representing the file contents.
|
A VectorStoreFileContentResponse representing the file contents.
|
||||||
=======
|
=======
|
||||||
File contents, optionally with embeddings and metadata based on extra_query
|
File contents, optionally with embeddings and metadata based on extra_query
|
||||||
|
=======
|
||||||
|
File contents, optionally with embeddings and metadata based on query
|
||||||
|
>>>>>>> c192529c (use FastAPI Query class instead of custom middlware)
|
||||||
parameters.
|
parameters.
|
||||||
>>>>>>> 639f0daa (feat: Adding optional embeddings to content)
|
>>>>>>> 639f0daa (feat: Adding optional embeddings to content)
|
||||||
content:
|
content:
|
||||||
|
|
@ -2731,23 +2735,20 @@ paths:
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
- name: extra_query
|
- name: include_embeddings
|
||||||
in: query
|
in: query
|
||||||
description: >-
|
description: >-
|
||||||
Optional extra parameters to control response format. Set include_embeddings=true
|
Whether to include embedding vectors in the response.
|
||||||
to include embedding vectors. Set include_metadata=true to include chunk
|
|
||||||
metadata.
|
|
||||||
required: false
|
required: false
|
||||||
schema:
|
schema:
|
||||||
type: object
|
$ref: '#/components/schemas/bool'
|
||||||
additionalProperties:
|
- name: include_metadata
|
||||||
oneOf:
|
in: query
|
||||||
- type: 'null'
|
description: >-
|
||||||
- type: boolean
|
Whether to include chunk metadata in the response.
|
||||||
- type: number
|
required: false
|
||||||
- type: string
|
schema:
|
||||||
- type: array
|
$ref: '#/components/schemas/bool'
|
||||||
- type: object
|
|
||||||
deprecated: false
|
deprecated: false
|
||||||
/v1/vector_stores/{vector_store_id}/search:
|
/v1/vector_stores/{vector_store_id}/search:
|
||||||
post:
|
post:
|
||||||
|
|
@ -10113,6 +10114,8 @@ components:
|
||||||
title: VectorStoreFileDeleteResponse
|
title: VectorStoreFileDeleteResponse
|
||||||
description: >-
|
description: >-
|
||||||
Response from deleting a vector store file.
|
Response from deleting a vector store file.
|
||||||
|
bool:
|
||||||
|
type: boolean
|
||||||
VectorStoreContent:
|
VectorStoreContent:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
|
||||||
|
|
@ -1,49 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
|
|
||||||
from fastapi import Request
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
|
|
||||||
from llama_stack.log import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core::middleware")
|
|
||||||
|
|
||||||
# Patterns for endpoints that need query parameter injection
|
|
||||||
QUERY_PARAM_ENDPOINTS = [
|
|
||||||
# /vector_stores/{vector_store_id}/files/{file_id}/content
|
|
||||||
re.compile(r"/vector_stores/[^/]+/files/[^/]+/content$"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class QueryParamsMiddleware(BaseHTTPMiddleware):
|
|
||||||
"""Middleware to inject query parameters into extra_query for specific endpoints"""
|
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next):
|
|
||||||
# Check if this is an endpoint that needs query parameter injection
|
|
||||||
if request.method == "GET" and any(pattern.search(str(request.url.path)) for pattern in QUERY_PARAM_ENDPOINTS):
|
|
||||||
# Extract all query parameters and convert to appropriate types
|
|
||||||
extra_query = {}
|
|
||||||
query_params = dict(request.query_params)
|
|
||||||
|
|
||||||
# Convert query parameters using JSON parsing for robust type conversion
|
|
||||||
for key, value in query_params.items():
|
|
||||||
try:
|
|
||||||
# parse as JSON to handles booleans, numbers, strings properly
|
|
||||||
extra_query[key] = json.loads(value)
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
# if parsing fails, keep as string
|
|
||||||
extra_query[key] = value
|
|
||||||
|
|
||||||
if extra_query:
|
|
||||||
# Store the extra_query in request state so we can access it later
|
|
||||||
request.state.extra_query = extra_query
|
|
||||||
logger.debug(f"QueryParamsMiddleware extracted extra_query: {extra_query}")
|
|
||||||
|
|
||||||
response = await call_next(request)
|
|
||||||
return response
|
|
||||||
|
|
@ -46,7 +46,6 @@ from llama_stack.core.request_headers import (
|
||||||
request_provider_data_context,
|
request_provider_data_context,
|
||||||
user_from_scope,
|
user_from_scope,
|
||||||
)
|
)
|
||||||
from llama_stack.core.server.query_params_middleware import QueryParamsMiddleware
|
|
||||||
from llama_stack.core.server.routes import get_all_api_routes
|
from llama_stack.core.server.routes import get_all_api_routes
|
||||||
from llama_stack.core.stack import (
|
from llama_stack.core.stack import (
|
||||||
Stack,
|
Stack,
|
||||||
|
|
@ -264,10 +263,6 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
||||||
|
|
||||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||||
|
|
||||||
# Inject extra_query from middleware if available
|
|
||||||
if hasattr(request.state, "extra_query") and request.state.extra_query:
|
|
||||||
kwargs["extra_query"] = request.state.extra_query
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
||||||
|
|
@ -407,9 +402,6 @@ def create_app() -> StackApp:
|
||||||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||||
app.add_middleware(ClientVersionMiddleware)
|
app.add_middleware(ClientVersionMiddleware)
|
||||||
|
|
||||||
# handle extra_query for specific GET requests
|
|
||||||
app.add_middleware(QueryParamsMiddleware)
|
|
||||||
|
|
||||||
impls = app.stack.impls
|
impls = app.stack.impls
|
||||||
|
|
||||||
if config.server.auth:
|
if config.server.auth:
|
||||||
|
|
|
||||||
|
|
@ -450,7 +450,7 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
# Now that our vector store is created, attach any files that were provided
|
# Now that our vector store is created, attach any files that were provided
|
||||||
file_ids = params.file_ids or []
|
file_ids = params.file_ids or []
|
||||||
tasks = [self.openai_attach_file_to_vector_store(vector_store_id, file_id) for file_id in file_ids]
|
tasks = [self.openai_attach_file_to_vector_store(vector_store_id, file_id) for file_id in file_ids]
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
# Get the updated store info and return it
|
# Get the updated store info and return it
|
||||||
store_info = self.openai_vector_stores[vector_store_id]
|
store_info = self.openai_vector_stores[vector_store_id]
|
||||||
|
|
@ -928,6 +928,9 @@ class OpenAIVectorStoreMixin(ABC):
|
||||||
if vector_store_id not in self.openai_vector_stores:
|
if vector_store_id not in self.openai_vector_stores:
|
||||||
raise VectorStoreNotFoundError(vector_store_id)
|
raise VectorStoreNotFoundError(vector_store_id)
|
||||||
|
|
||||||
|
# Parameters are already provided directly
|
||||||
|
# include_embeddings and include_metadata are now function parameters
|
||||||
|
|
||||||
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
|
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
|
||||||
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
|
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
|
||||||
chunks = [Chunk.model_validate(c) for c in dict_chunks]
|
chunks = [Chunk.model_validate(c) for c in dict_chunks]
|
||||||
|
|
|
||||||
|
|
@ -1,86 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, Mock
|
|
||||||
|
|
||||||
from fastapi import Request
|
|
||||||
|
|
||||||
from llama_stack.core.server.query_params_middleware import QueryParamsMiddleware
|
|
||||||
|
|
||||||
|
|
||||||
class TestQueryParamsMiddleware:
|
|
||||||
"""Test cases for the QueryParamsMiddleware."""
|
|
||||||
|
|
||||||
async def test_extracts_query_params_for_vector_store_content(self):
|
|
||||||
"""Test that middleware extracts query params for vector store content endpoints."""
|
|
||||||
middleware = QueryParamsMiddleware(Mock())
|
|
||||||
request = Mock(spec=Request)
|
|
||||||
request.method = "GET"
|
|
||||||
|
|
||||||
# Mock the URL properly
|
|
||||||
mock_url = Mock()
|
|
||||||
mock_url.path = "/v1/vector_stores/vs_123/files/file_456/content"
|
|
||||||
request.url = mock_url
|
|
||||||
|
|
||||||
request.query_params = {"include_embeddings": "true", "include_metadata": "false"}
|
|
||||||
|
|
||||||
# Create a fresh state object without any attributes
|
|
||||||
class MockState:
|
|
||||||
pass
|
|
||||||
|
|
||||||
request.state = MockState()
|
|
||||||
|
|
||||||
await middleware.dispatch(request, AsyncMock())
|
|
||||||
|
|
||||||
assert hasattr(request.state, "extra_query")
|
|
||||||
assert request.state.extra_query == {"include_embeddings": True, "include_metadata": False}
|
|
||||||
|
|
||||||
async def test_ignores_non_vector_store_endpoints(self):
|
|
||||||
"""Test that middleware ignores non-vector store endpoints."""
|
|
||||||
middleware = QueryParamsMiddleware(Mock())
|
|
||||||
request = Mock(spec=Request)
|
|
||||||
request.method = "GET"
|
|
||||||
|
|
||||||
# Mock the URL properly
|
|
||||||
mock_url = Mock()
|
|
||||||
mock_url.path = "/v1/inference/chat_completion"
|
|
||||||
request.url = mock_url
|
|
||||||
|
|
||||||
request.query_params = {"include_embeddings": "true"}
|
|
||||||
|
|
||||||
# Create a fresh state object without any attributes
|
|
||||||
class MockState:
|
|
||||||
pass
|
|
||||||
|
|
||||||
request.state = MockState()
|
|
||||||
|
|
||||||
await middleware.dispatch(request, AsyncMock())
|
|
||||||
|
|
||||||
assert not hasattr(request.state, "extra_query")
|
|
||||||
|
|
||||||
async def test_handles_json_parsing(self):
|
|
||||||
"""Test that middleware correctly parses JSON values and handles invalid JSON."""
|
|
||||||
middleware = QueryParamsMiddleware(Mock())
|
|
||||||
request = Mock(spec=Request)
|
|
||||||
request.method = "GET"
|
|
||||||
|
|
||||||
# Mock the URL properly
|
|
||||||
mock_url = Mock()
|
|
||||||
mock_url.path = "/v1/vector_stores/vs_123/files/file_456/content"
|
|
||||||
request.url = mock_url
|
|
||||||
|
|
||||||
request.query_params = {"config": '{"key": "value"}', "invalid": "not-json{", "number": "42"}
|
|
||||||
|
|
||||||
# Create a fresh state object without any attributes
|
|
||||||
class MockState:
|
|
||||||
pass
|
|
||||||
|
|
||||||
request.state = MockState()
|
|
||||||
|
|
||||||
await middleware.dispatch(request, AsyncMock())
|
|
||||||
|
|
||||||
expected = {"config": {"key": "value"}, "invalid": "not-json{", "number": 42}
|
|
||||||
assert request.state.extra_query == expected
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue