From 606f4cf2819fd6a79c09d13c97685f96add18b2e Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Tue, 30 Sep 2025 13:14:03 -0700
Subject: [PATCH] fix(expires_after): make sure multipart/form-data is properly parsed (#3612)

https://github.com/llamastack/llama-stack/pull/3604 broke multipart form data field parsing for the Files API since it changed its shape -- so as to match the API exactly to the OpenAI spec even in the generated client code.

The underlying reason is that multipart/form-data cannot transport structured nested fields. Each field must be str-serialized. The client (specifically the OpenAI client, whose behavior we must match) transports sub-fields as `expires_after[anchor]` and `expires_after[seconds]`, etc. We must be able to handle these fields somehow on the server without compromising the shape of the YAML spec.

This PR "fixes" this by adding a dependency to convert the data. The main trade-off here is that we must add this `Depends()` annotation on every provider implementation for Files. This is a headache, but a much more reasonable one (in my opinion) given the alternatives.

## Test Plan

Tests as shown in https://github.com/llamastack/llama-stack/pull/3604#issuecomment-3351090653 pass.
--- .../providers/inline/files/localfs/files.py | 5 +- .../providers/remote/files/s3/files.py | 5 +- llama_stack/providers/utils/files/__init__.py | 5 + .../providers/utils/files/form_data.py | 69 +++++++ tests/unit/providers/utils/test_form_data.py | 179 ++++++++++++++++++ 5 files changed, 259 insertions(+), 4 deletions(-) create mode 100644 llama_stack/providers/utils/files/__init__.py create mode 100644 llama_stack/providers/utils/files/form_data.py create mode 100644 tests/unit/providers/utils/test_form_data.py diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index 6e0c72de3..be1da291a 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -9,7 +9,7 @@ import uuid from pathlib import Path from typing import Annotated -from fastapi import File, Form, Response, UploadFile +from fastapi import Depends, File, Form, Response, UploadFile from llama_stack.apis.common.errors import ResourceNotFoundError from llama_stack.apis.common.responses import Order @@ -23,6 +23,7 @@ from llama_stack.apis.files import ( ) from llama_stack.core.datatypes import AccessRule from llama_stack.log import get_logger +from llama_stack.providers.utils.files.form_data import parse_expires_after from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl @@ -87,7 +88,7 @@ class LocalfsFilesImpl(Files): self, file: Annotated[UploadFile, File()], purpose: Annotated[OpenAIFilePurpose, Form()], - expires_after: Annotated[ExpiresAfter | None, Form()] = None, + expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None, ) -> OpenAIFileObject: """Upload a file that can be used across various endpoints.""" if not self.sql_store: diff --git 
a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py index 8520f70b6..eb339b31e 100644 --- a/llama_stack/providers/remote/files/s3/files.py +++ b/llama_stack/providers/remote/files/s3/files.py @@ -10,7 +10,7 @@ from typing import Annotated, Any import boto3 from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError -from fastapi import File, Form, Response, UploadFile +from fastapi import Depends, File, Form, Response, UploadFile from llama_stack.apis.common.errors import ResourceNotFoundError from llama_stack.apis.common.responses import Order @@ -23,6 +23,7 @@ from llama_stack.apis.files import ( OpenAIFilePurpose, ) from llama_stack.core.datatypes import AccessRule +from llama_stack.providers.utils.files.form_data import parse_expires_after from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl @@ -195,7 +196,7 @@ class S3FilesImpl(Files): self, file: Annotated[UploadFile, File()], purpose: Annotated[OpenAIFilePurpose, Form()], - expires_after: Annotated[ExpiresAfter | None, Form()] = None, + expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None, ) -> OpenAIFileObject: file_id = f"file-{uuid.uuid4().hex}" diff --git a/llama_stack/providers/utils/files/__init__.py b/llama_stack/providers/utils/files/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/utils/files/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
diff --git a/llama_stack/providers/utils/files/form_data.py b/llama_stack/providers/utils/files/form_data.py new file mode 100644 index 000000000..3d8fb6d85 --- /dev/null +++ b/llama_stack/providers/utils/files/form_data.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json + +from fastapi import Request +from pydantic import BaseModel, ValidationError + +from llama_stack.apis.files import ExpiresAfter + + +async def parse_pydantic_from_form[T: BaseModel](request: Request, field_name: str, model_class: type[T]) -> T | None: + """ + Generic parser to extract a Pydantic model from multipart form data. + Handles both bracket notation (field[attr1], field[attr2]) and JSON string format. + + Args: + request: The FastAPI request object + field_name: The name of the field in the form data (e.g., "expires_after") + model_class: The Pydantic model class to parse into + + Returns: + An instance of model_class if parsing succeeds, None otherwise + + Example: + expires_after = await parse_pydantic_from_form( + request, "expires_after", ExpiresAfter + ) + """ + form = await request.form() + + # Check for bracket notation first (e.g., expires_after[anchor], expires_after[seconds]) + bracket_data = {} + prefix = f"{field_name}[" + for key in form.keys(): + if key.startswith(prefix) and key.endswith("]"): + # Extract the attribute name from field_name[attr] + attr = key[len(prefix) : -1] + bracket_data[attr] = form[key] + + if bracket_data: + try: + return model_class(**bracket_data) + except (ValidationError, TypeError): + pass + + # Check for JSON string format + if field_name in form: + value = form[field_name] + if isinstance(value, str): + try: + data = json.loads(value) + return model_class(**data) + except (json.JSONDecodeError, TypeError, ValidationError): + pass + + return None + + +async 
def parse_expires_after(request: Request) -> ExpiresAfter | None: + """ + Dependency to parse expires_after from multipart form data. + Handles both bracket notation (expires_after[anchor], expires_after[seconds]) + and JSON string format. + """ + return await parse_pydantic_from_form(request, "expires_after", ExpiresAfter) diff --git a/tests/unit/providers/utils/test_form_data.py b/tests/unit/providers/utils/test_form_data.py new file mode 100644 index 000000000..a27ba4be7 --- /dev/null +++ b/tests/unit/providers/utils/test_form_data.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from unittest.mock import AsyncMock, MagicMock + +from pydantic import BaseModel + +from llama_stack.providers.utils.files.form_data import ( + parse_expires_after, + parse_pydantic_from_form, +) + + +class _TestModel(BaseModel): + """Simple test model for generic parsing tests.""" + + name: str + value: int + + +async def test_parse_pydantic_from_form_bracket_notation(): + """Test parsing a Pydantic model using bracket notation.""" + # Create mock request with form data + mock_request = MagicMock() + mock_form = { + "test_field[name]": "test_name", + "test_field[value]": "42", + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel) + + assert result is not None + assert result.name == "test_name" + assert result.value == 42 + + +async def test_parse_pydantic_from_form_json_string(): + """Test parsing a Pydantic model from JSON string.""" + # Create mock request with form data + mock_request = MagicMock() + test_data = {"name": "test_name", "value": 42} + mock_form = { + "test_field": json.dumps(test_data), + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await 
parse_pydantic_from_form(mock_request, "test_field", _TestModel) + + assert result is not None + assert result.name == "test_name" + assert result.value == 42 + + +async def test_parse_pydantic_from_form_bracket_takes_precedence(): + """Test that bracket notation takes precedence over JSON string.""" + # Create mock request with both formats + mock_request = MagicMock() + mock_form = { + "test_field[name]": "bracket_name", + "test_field[value]": "100", + "test_field": json.dumps({"name": "json_name", "value": 50}), + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel) + + assert result is not None + # Bracket notation should win + assert result.name == "bracket_name" + assert result.value == 100 + + +async def test_parse_pydantic_from_form_missing_field(): + """Test that None is returned when field is missing.""" + # Create mock request with empty form + mock_request = MagicMock() + mock_form = {} + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel) + + assert result is None + + +async def test_parse_pydantic_from_form_invalid_json(): + """Test that None is returned for invalid JSON.""" + # Create mock request with invalid JSON + mock_request = MagicMock() + mock_form = { + "test_field": "not valid json", + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel) + + assert result is None + + +async def test_parse_pydantic_from_form_invalid_data(): + """Test that None is returned when data doesn't match model.""" + # Create mock request with data that doesn't match the model + mock_request = MagicMock() + mock_form = { + "test_field[wrong_field]": "value", + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel) + + assert 
result is None + + +async def test_parse_expires_after_bracket_notation(): + """Test parsing expires_after using bracket notation.""" + # Create mock request with form data + mock_request = MagicMock() + mock_form = { + "expires_after[anchor]": "created_at", + "expires_after[seconds]": "3600", + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_expires_after(mock_request) + + assert result is not None + assert result.anchor == "created_at" + assert result.seconds == 3600 + + +async def test_parse_expires_after_json_string(): + """Test parsing expires_after from JSON string.""" + # Create mock request with form data + mock_request = MagicMock() + expires_data = {"anchor": "created_at", "seconds": 7200} + mock_form = { + "expires_after": json.dumps(expires_data), + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_expires_after(mock_request) + + assert result is not None + assert result.anchor == "created_at" + assert result.seconds == 7200 + + +async def test_parse_expires_after_missing(): + """Test that None is returned when expires_after is missing.""" + # Create mock request with empty form + mock_request = MagicMock() + mock_form = {} + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_expires_after(mock_request) + + assert result is None + + +async def test_parse_pydantic_from_form_type_conversion(): + """Test that bracket notation properly handles type conversion.""" + # Create mock request with string values that need conversion + mock_request = MagicMock() + mock_form = { + "test_field[name]": "test", + "test_field[value]": "999", # String that should be converted to int + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel) + + assert result is not None + assert result.name == "test" + assert result.value == 999 + assert isinstance(result.value, int)