diff --git a/llama_stack/providers/utils/files/form_data.py b/llama_stack/providers/utils/files/form_data.py index 0e3cfc8d1..3d8fb6d85 100644 --- a/llama_stack/providers/utils/files/form_data.py +++ b/llama_stack/providers/utils/files/form_data.py @@ -4,38 +4,66 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json + from fastapi import Request +from pydantic import BaseModel, ValidationError from llama_stack.apis.files import ExpiresAfter +async def parse_pydantic_from_form[T: BaseModel](request: Request, field_name: str, model_class: type[T]) -> T | None: + """ + Generic parser to extract a Pydantic model from multipart form data. + Handles both bracket notation (field[attr1], field[attr2]) and JSON string format. + + Args: + request: The FastAPI request object + field_name: The name of the field in the form data (e.g., "expires_after") + model_class: The Pydantic model class to parse into + + Returns: + An instance of model_class if parsing succeeds, None otherwise + + Example: + expires_after = await parse_pydantic_from_form( + request, "expires_after", ExpiresAfter + ) + """ + form = await request.form() + + # Check for bracket notation first (e.g., expires_after[anchor], expires_after[seconds]) + bracket_data = {} + prefix = f"{field_name}[" + for key in form.keys(): + if key.startswith(prefix) and key.endswith("]"): + # Extract the attribute name from field_name[attr] + attr = key[len(prefix) : -1] + bracket_data[attr] = form[key] + + if bracket_data: + try: + return model_class(**bracket_data) + except (ValidationError, TypeError): + pass + + # Check for JSON string format + if field_name in form: + value = form[field_name] + if isinstance(value, str): + try: + data = json.loads(value) + return model_class(**data) + except (json.JSONDecodeError, TypeError, ValidationError): + pass + + return None + + async def parse_expires_after(request: Request) -> ExpiresAfter | None: """ Dependency to parse expires_after from multipart form data. Handles both bracket notation (expires_after[anchor], expires_after[seconds]) and JSON string format. """ - form = await request.form() - - # Check for bracket notation first - anchor_key = "expires_after[anchor]" - seconds_key = "expires_after[seconds]" - - if anchor_key in form and seconds_key in form: - anchor = form[anchor_key] - seconds = form[seconds_key] - return ExpiresAfter(anchor=anchor, seconds=int(seconds)) - - # Check for JSON string format - if "expires_after" in form: - value = form["expires_after"] - if isinstance(value, str): - import json - - try: - data = json.loads(value) - return ExpiresAfter(**data) - except (json.JSONDecodeError, TypeError): - pass - - return None + return await parse_pydantic_from_form(request, "expires_after", ExpiresAfter) diff --git a/tests/unit/providers/utils/test_form_data.py b/tests/unit/providers/utils/test_form_data.py new file mode 100644 index 000000000..a27ba4be7 --- /dev/null +++ b/tests/unit/providers/utils/test_form_data.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from unittest.mock import AsyncMock, MagicMock + +from pydantic import BaseModel + +from llama_stack.providers.utils.files.form_data import ( + parse_expires_after, + parse_pydantic_from_form, +) + + +class _TestModel(BaseModel): + """Simple test model for generic parsing tests.""" + + name: str + value: int + + +async def test_parse_pydantic_from_form_bracket_notation(): + """Test parsing a Pydantic model using bracket notation.""" + # Create mock request with form data + mock_request = MagicMock() + mock_form = { + "test_field[name]": "test_name", + "test_field[value]": "42", + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel) + + assert result is not None + assert result.name == "test_name" + assert result.value == 42 + + +async def test_parse_pydantic_from_form_json_string(): + """Test parsing a Pydantic model from JSON string.""" + # Create mock request with form data + mock_request = MagicMock() + test_data = {"name": "test_name", "value": 42} + mock_form = { + "test_field": json.dumps(test_data), + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel) + + assert result is not None + assert result.name == "test_name" + assert result.value == 42 + + +async def test_parse_pydantic_from_form_bracket_takes_precedence(): + """Test that bracket notation takes precedence over JSON string.""" + # Create mock request with both formats + mock_request = MagicMock() + mock_form = { + "test_field[name]": "bracket_name", + "test_field[value]": "100", + "test_field": json.dumps({"name": "json_name", "value": 50}), + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel) + + assert result is not None + # Bracket notation should win + assert result.name == "bracket_name" + assert result.value == 100 + + +async def test_parse_pydantic_from_form_missing_field(): + """Test that None is returned when field is missing.""" + # Create mock request with empty form + mock_request = MagicMock() + mock_form = {} + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel) + + assert result is None + + +async def test_parse_pydantic_from_form_invalid_json(): + """Test that None is returned for invalid JSON.""" + # Create mock request with invalid JSON + mock_request = MagicMock() + mock_form = { + "test_field": "not valid json", + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel) + + assert result is None + + +async def test_parse_pydantic_from_form_invalid_data(): + """Test that None is returned when data doesn't match model.""" + # Create mock request with data that doesn't match the model + mock_request = MagicMock() + mock_form = { + "test_field[wrong_field]": "value", + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel) + + assert result is None + + +async def test_parse_expires_after_bracket_notation(): + """Test parsing expires_after using bracket notation.""" + # Create mock request with form data + mock_request = MagicMock() + mock_form = { + "expires_after[anchor]": "created_at", + "expires_after[seconds]": "3600", + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_expires_after(mock_request) + + assert result is not None + assert result.anchor == "created_at" + assert result.seconds == 3600 + + +async def test_parse_expires_after_json_string(): + """Test parsing expires_after from JSON string.""" + # Create mock request with form data + mock_request = MagicMock() + expires_data = {"anchor": "created_at", "seconds": 7200} + mock_form = { + "expires_after": json.dumps(expires_data), + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_expires_after(mock_request) + + assert result is not None + assert result.anchor == "created_at" + assert result.seconds == 7200 + + +async def test_parse_expires_after_missing(): + """Test that None is returned when expires_after is missing.""" + # Create mock request with empty form + mock_request = MagicMock() + mock_form = {} + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_expires_after(mock_request) + + assert result is None + + +async def test_parse_pydantic_from_form_type_conversion(): + """Test that bracket notation properly handles type conversion.""" + # Create mock request with string values that need conversion + mock_request = MagicMock() + mock_form = { + "test_field[name]": "test", + "test_field[value]": "999", # String that should be converted to int + } + mock_request.form = AsyncMock(return_value=mock_form) + + result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel) + + assert result is not None + assert result.name == "test" + assert result.value == 999 + assert isinstance(result.value, int)