fix(expires_after): make sure multipart/form-data is properly parsed (#3612)

https://github.com/llamastack/llama-stack/pull/3604 broke multipart form
data field parsing for the Files API since it changed its shape -- so as
to match the API exactly to the OpenAI spec even in the generated client
code.

The underlying reason is that multipart/form-data cannot transport
structured nested fields. Each field must be str-serialized. The client
(specifically the OpenAI client whose behavior we must match),
transports sub-fields as `expires_after[anchor]` and
`expires_after[seconds]`, etc. We must be able to handle these fields
somehow on the server without compromising the shape of the YAML spec.

This PR "fixes" this by adding a dependency to convert the data. The
main trade-off here is that we must add this `Depends()` annotation on
every provider implementation for Files. This is a headache, but a much
more reasonable one (in my opinion) given the alternatives.

## Test Plan

Tests as shown in
https://github.com/llamastack/llama-stack/pull/3604#issuecomment-3351090653
pass.
This commit is contained in:
Ashwin Bharambe 2025-09-30 13:14:03 -07:00 committed by GitHub
parent 73de235ef1
commit 606f4cf281
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 259 additions and 4 deletions

View file

@ -9,7 +9,7 @@ import uuid
from pathlib import Path from pathlib import Path
from typing import Annotated from typing import Annotated
from fastapi import File, Form, Response, UploadFile from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order from llama_stack.apis.common.responses import Order
@ -23,6 +23,7 @@ from llama_stack.apis.files import (
) )
from llama_stack.core.datatypes import AccessRule from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
@ -87,7 +88,7 @@ class LocalfsFilesImpl(Files):
self, self,
file: Annotated[UploadFile, File()], file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()], purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Form()] = None, expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject: ) -> OpenAIFileObject:
"""Upload a file that can be used across various endpoints.""" """Upload a file that can be used across various endpoints."""
if not self.sql_store: if not self.sql_store:

View file

@ -10,7 +10,7 @@ from typing import Annotated, Any
import boto3 import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import File, Form, Response, UploadFile from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order from llama_stack.apis.common.responses import Order
@ -23,6 +23,7 @@ from llama_stack.apis.files import (
OpenAIFilePurpose, OpenAIFilePurpose,
) )
from llama_stack.core.datatypes import AccessRule from llama_stack.core.datatypes import AccessRule
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
@ -195,7 +196,7 @@ class S3FilesImpl(Files):
self, self,
file: Annotated[UploadFile, File()], file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()], purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Form()] = None, expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject: ) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}" file_id = f"file-{uuid.uuid4().hex}"

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from fastapi import Request
from pydantic import BaseModel, ValidationError
from llama_stack.apis.files import ExpiresAfter
async def parse_pydantic_from_form[T: BaseModel](request: Request, field_name: str, model_class: type[T]) -> T | None:
"""
Generic parser to extract a Pydantic model from multipart form data.
Handles both bracket notation (field[attr1], field[attr2]) and JSON string format.
Args:
request: The FastAPI request object
field_name: The name of the field in the form data (e.g., "expires_after")
model_class: The Pydantic model class to parse into
Returns:
An instance of model_class if parsing succeeds, None otherwise
Example:
expires_after = await parse_pydantic_from_form(
request, "expires_after", ExpiresAfter
)
"""
form = await request.form()
# Check for bracket notation first (e.g., expires_after[anchor], expires_after[seconds])
bracket_data = {}
prefix = f"{field_name}["
for key in form.keys():
if key.startswith(prefix) and key.endswith("]"):
# Extract the attribute name from field_name[attr]
attr = key[len(prefix) : -1]
bracket_data[attr] = form[key]
if bracket_data:
try:
return model_class(**bracket_data)
except (ValidationError, TypeError):
pass
# Check for JSON string format
if field_name in form:
value = form[field_name]
if isinstance(value, str):
try:
data = json.loads(value)
return model_class(**data)
except (json.JSONDecodeError, TypeError, ValidationError):
pass
return None
async def parse_expires_after(request: Request) -> ExpiresAfter | None:
"""
Dependency to parse expires_after from multipart form data.
Handles both bracket notation (expires_after[anchor], expires_after[seconds])
and JSON string format.
"""
return await parse_pydantic_from_form(request, "expires_after", ExpiresAfter)

View file

@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from unittest.mock import AsyncMock, MagicMock
from pydantic import BaseModel
from llama_stack.providers.utils.files.form_data import (
parse_expires_after,
parse_pydantic_from_form,
)
class _TestModel(BaseModel):
"""Simple test model for generic parsing tests."""
name: str
value: int
async def test_parse_pydantic_from_form_bracket_notation():
"""Test parsing a Pydantic model using bracket notation."""
# Create mock request with form data
mock_request = MagicMock()
mock_form = {
"test_field[name]": "test_name",
"test_field[value]": "42",
}
mock_request.form = AsyncMock(return_value=mock_form)
result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel)
assert result is not None
assert result.name == "test_name"
assert result.value == 42
async def test_parse_pydantic_from_form_json_string():
"""Test parsing a Pydantic model from JSON string."""
# Create mock request with form data
mock_request = MagicMock()
test_data = {"name": "test_name", "value": 42}
mock_form = {
"test_field": json.dumps(test_data),
}
mock_request.form = AsyncMock(return_value=mock_form)
result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel)
assert result is not None
assert result.name == "test_name"
assert result.value == 42
async def test_parse_pydantic_from_form_bracket_takes_precedence():
"""Test that bracket notation takes precedence over JSON string."""
# Create mock request with both formats
mock_request = MagicMock()
mock_form = {
"test_field[name]": "bracket_name",
"test_field[value]": "100",
"test_field": json.dumps({"name": "json_name", "value": 50}),
}
mock_request.form = AsyncMock(return_value=mock_form)
result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel)
assert result is not None
# Bracket notation should win
assert result.name == "bracket_name"
assert result.value == 100
async def test_parse_pydantic_from_form_missing_field():
"""Test that None is returned when field is missing."""
# Create mock request with empty form
mock_request = MagicMock()
mock_form = {}
mock_request.form = AsyncMock(return_value=mock_form)
result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel)
assert result is None
async def test_parse_pydantic_from_form_invalid_json():
"""Test that None is returned for invalid JSON."""
# Create mock request with invalid JSON
mock_request = MagicMock()
mock_form = {
"test_field": "not valid json",
}
mock_request.form = AsyncMock(return_value=mock_form)
result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel)
assert result is None
async def test_parse_pydantic_from_form_invalid_data():
"""Test that None is returned when data doesn't match model."""
# Create mock request with data that doesn't match the model
mock_request = MagicMock()
mock_form = {
"test_field[wrong_field]": "value",
}
mock_request.form = AsyncMock(return_value=mock_form)
result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel)
assert result is None
async def test_parse_expires_after_bracket_notation():
"""Test parsing expires_after using bracket notation."""
# Create mock request with form data
mock_request = MagicMock()
mock_form = {
"expires_after[anchor]": "created_at",
"expires_after[seconds]": "3600",
}
mock_request.form = AsyncMock(return_value=mock_form)
result = await parse_expires_after(mock_request)
assert result is not None
assert result.anchor == "created_at"
assert result.seconds == 3600
async def test_parse_expires_after_json_string():
"""Test parsing expires_after from JSON string."""
# Create mock request with form data
mock_request = MagicMock()
expires_data = {"anchor": "created_at", "seconds": 7200}
mock_form = {
"expires_after": json.dumps(expires_data),
}
mock_request.form = AsyncMock(return_value=mock_form)
result = await parse_expires_after(mock_request)
assert result is not None
assert result.anchor == "created_at"
assert result.seconds == 7200
async def test_parse_expires_after_missing():
"""Test that None is returned when expires_after is missing."""
# Create mock request with empty form
mock_request = MagicMock()
mock_form = {}
mock_request.form = AsyncMock(return_value=mock_form)
result = await parse_expires_after(mock_request)
assert result is None
async def test_parse_pydantic_from_form_type_conversion():
"""Test that bracket notation properly handles type conversion."""
# Create mock request with string values that need conversion
mock_request = MagicMock()
mock_form = {
"test_field[name]": "test",
"test_field[value]": "999", # String that should be converted to int
}
mock_request.form = AsyncMock(return_value=mock_form)
result = await parse_pydantic_from_form(mock_request, "test_field", _TestModel)
assert result is not None
assert result.name == "test"
assert result.value == 999
assert isinstance(result.value, int)