[Bug Fix] Add support for UploadFile on LLM Pass through endpoints (OpenAI, Azure etc) (#9853)

* http passthrough file handling

* fix make_multipart_http_request

* test_pass_through_file_operations

* unit tests for file handling
This commit is contained in:
Ishaan Jaff 2025-04-09 15:29:20 -07:00 committed by GitHub
parent 6ba3c4a4f8
commit 08a3620414
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 241 additions and 19 deletions

View file

@ -4,16 +4,26 @@ import json
import uuid
from base64 import b64encode
from datetime import datetime
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import parse_qs, urlencode, urlparse
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
Response,
UploadFile,
status,
)
from fastapi.responses import StreamingResponse
from starlette.datastructures import UploadFile as StarletteUploadFile
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.proxy._types import (
ConfigFieldInfo,
@ -358,6 +368,92 @@ class HttpPassThroughEndpointHelpers:
)
return response
@staticmethod
async def non_streaming_http_request_handler(
request: Request,
async_client: httpx.AsyncClient,
url: httpx.URL,
headers: dict,
requested_query_params: Optional[dict] = None,
_parsed_body: Optional[dict] = None,
) -> httpx.Response:
"""
Handle non-streaming HTTP requests
Handles special cases when GET requests, multipart/form-data requests, and generic httpx requests
"""
if request.method == "GET":
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
)
elif HttpPassThroughEndpointHelpers.is_multipart(request) is True:
return await HttpPassThroughEndpointHelpers.make_multipart_http_request(
request=request,
async_client=async_client,
url=url,
headers=headers,
requested_query_params=requested_query_params,
)
else:
# Generic httpx method
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
json=_parsed_body,
)
return response
@staticmethod
def is_multipart(request: Request) -> bool:
"""Check if the request is a multipart/form-data request"""
return "multipart/form-data" in request.headers.get("content-type", "")
@staticmethod
async def _build_request_files_from_upload_file(
upload_file: Union[UploadFile, StarletteUploadFile],
) -> Tuple[Optional[str], bytes, Optional[str]]:
"""Build a request files dict from an UploadFile object"""
file_content = await upload_file.read()
return (upload_file.filename, file_content, upload_file.content_type)
@staticmethod
async def make_multipart_http_request(
request: Request,
async_client: httpx.AsyncClient,
url: httpx.URL,
headers: dict,
requested_query_params: Optional[dict] = None,
) -> httpx.Response:
"""Process multipart/form-data requests, handling both files and form fields"""
form_data = await request.form()
files = {}
form_data_dict = {}
for field_name, field_value in form_data.items():
if isinstance(field_value, (StarletteUploadFile, UploadFile)):
files[field_name] = (
await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
upload_file=field_value
)
)
else:
form_data_dict[field_name] = field_value
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
files=files,
data=form_data_dict,
)
return response
async def pass_through_request( # noqa: PLR0915
request: Request,
@ -424,7 +520,7 @@ async def pass_through_request( # noqa: PLR0915
start_time = datetime.now()
logging_obj = Logging(
model="unknown",
messages=[{"role": "user", "content": json.dumps(_parsed_body)}],
messages=[{"role": "user", "content": safe_dumps(_parsed_body)}],
stream=False,
call_type="pass_through_endpoint",
start_time=start_time,
@ -453,7 +549,6 @@ async def pass_through_request( # noqa: PLR0915
logging_obj.model_call_details["litellm_call_id"] = litellm_call_id
# combine url with query params for logging
requested_query_params: Optional[dict] = (
query_params or request.query_params.__dict__
)
@ -474,7 +569,7 @@ async def pass_through_request( # noqa: PLR0915
logging_url = str(url) + "?" + requested_query_params_str
logging_obj.pre_call(
input=[{"role": "user", "content": json.dumps(_parsed_body)}],
input=[{"role": "user", "content": safe_dumps(_parsed_body)}],
api_key="",
additional_args={
"complete_input_dict": _parsed_body,
@ -525,22 +620,16 @@ async def pass_through_request( # noqa: PLR0915
)
verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
if request.method == "GET":
response = await async_client.request(
method=request.method,
response = (
await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler(
request=request,
async_client=async_client,
url=url,
headers=headers,
params=requested_query_params,
requested_query_params=requested_query_params,
_parsed_body=_parsed_body,
)
else:
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
json=_parsed_body,
)
)
verbose_proxy_logger.debug("response.headers= %s", response.headers)
if _is_streaming_response(response) is True:

View file

@ -0,0 +1,116 @@
import json
import os
import sys
from io import BytesIO
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from fastapi import Request, UploadFile
from fastapi.testclient import TestClient
from starlette.datastructures import Headers
from starlette.datastructures import UploadFile as StarletteUploadFile
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
HttpPassThroughEndpointHelpers,
)
# Test is_multipart
def test_is_multipart():
# Test with multipart content type
request = MagicMock(spec=Request)
request.headers = Headers({"content-type": "multipart/form-data; boundary=123"})
assert HttpPassThroughEndpointHelpers.is_multipart(request) is True
# Test with non-multipart content type
request.headers = Headers({"content-type": "application/json"})
assert HttpPassThroughEndpointHelpers.is_multipart(request) is False
# Test with no content type
request.headers = Headers({})
assert HttpPassThroughEndpointHelpers.is_multipart(request) is False
# Test _build_request_files_from_upload_file
@pytest.mark.asyncio
async def test_build_request_files_from_upload_file():
# Test with FastAPI UploadFile
file_content = b"test content"
file = BytesIO(file_content)
# Create SpooledTemporaryFile with content type headers
headers = {"content-type": "text/plain"}
upload_file = UploadFile(file=file, filename="test.txt", headers=headers)
upload_file.read = AsyncMock(return_value=file_content)
result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
upload_file
)
assert result == ("test.txt", file_content, "text/plain")
# Test with Starlette UploadFile
file2 = BytesIO(file_content)
starlette_file = StarletteUploadFile(
file=file2,
filename="test2.txt",
headers=Headers({"content-type": "text/plain"}),
)
starlette_file.read = AsyncMock(return_value=file_content)
result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
starlette_file
)
assert result == ("test2.txt", file_content, "text/plain")
# Test make_multipart_http_request
@pytest.mark.asyncio
async def test_make_multipart_http_request():
# Mock request with file and form field
request = MagicMock(spec=Request)
request.method = "POST"
# Mock form data
file_content = b"test file content"
file = BytesIO(file_content)
# Create SpooledTemporaryFile with content type headers
headers = {"content-type": "text/plain"}
upload_file = UploadFile(file=file, filename="test.txt", headers=headers)
upload_file.read = AsyncMock(return_value=file_content)
form_data = {"file": upload_file, "text_field": "test value"}
request.form = AsyncMock(return_value=form_data)
# Mock httpx client
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {}
async_client = MagicMock()
async_client.request = AsyncMock(return_value=mock_response)
# Test the function
response = await HttpPassThroughEndpointHelpers.make_multipart_http_request(
request=request,
async_client=async_client,
url=httpx.URL("http://test.com"),
headers={},
requested_query_params=None,
)
# Verify the response
assert response == mock_response
# Verify the client call
async_client.request.assert_called_once()
call_args = async_client.request.call_args[1]
assert call_args["method"] == "POST"
assert str(call_args["url"]) == "http://test.com"
assert isinstance(call_args["files"], dict)
assert isinstance(call_args["data"], dict)
assert call_args["data"]["text_field"] == "test value"

View file

@ -2,14 +2,31 @@ import pytest
import openai
import aiohttp
import asyncio
import tempfile
from typing_extensions import override
from openai import AssistantEventHandler
client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234")
def test_pass_through_file_operations():
# Create a temporary file
with tempfile.NamedTemporaryFile(mode='w+', suffix='.txt', delete=False) as temp_file:
temp_file.write("This is a test file for the OpenAI Assistants API.")
temp_file.flush()
# create a file
file = client.files.create(
file=open(temp_file.name, "rb"),
purpose="assistants",
)
print("file created", file)
# delete the file
delete_file = client.files.delete(file.id)
print("file deleted", delete_file)
def test_openai_assistants_e2e_operations():
assistant = client.beta.assistants.create(
name="Math Tutor",
instructions="You are a personal math tutor. Write and run code to answer math questions.",