[Bug Fix] Add support for UploadFile on LLM Pass through endpoints (OpenAI, Azure etc) (#9853)

* http passthrough file handling

* fix make_multipart_http_request

* test_pass_through_file_operations

* unit tests for file handling
This commit is contained in:
Ishaan Jaff 2025-04-09 15:29:20 -07:00 committed by GitHub
parent 6ba3c4a4f8
commit 08a3620414
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 241 additions and 19 deletions

View file

@ -4,16 +4,26 @@ import json
import uuid
from base64 import b64encode
from datetime import datetime
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import parse_qs, urlencode, urlparse
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
Response,
UploadFile,
status,
)
from fastapi.responses import StreamingResponse
from starlette.datastructures import UploadFile as StarletteUploadFile
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.proxy._types import (
ConfigFieldInfo,
@ -358,6 +368,92 @@ class HttpPassThroughEndpointHelpers:
)
return response
@staticmethod
async def non_streaming_http_request_handler(
    request: Request,
    async_client: httpx.AsyncClient,
    url: httpx.URL,
    headers: dict,
    requested_query_params: Optional[dict] = None,
    _parsed_body: Optional[dict] = None,
) -> httpx.Response:
    """
    Forward a non-streaming HTTP request to the upstream endpoint.

    Dispatches on the request shape: GET requests (no body), multipart/form-data
    requests (file uploads), and everything else (JSON body pass-through).
    """
    # GET requests carry no body — forward only the query params.
    if request.method == "GET":
        return await async_client.request(
            method=request.method,
            url=url,
            headers=headers,
            params=requested_query_params,
        )

    # File uploads: re-assemble the multipart body before forwarding.
    if HttpPassThroughEndpointHelpers.is_multipart(request) is True:
        return await HttpPassThroughEndpointHelpers.make_multipart_http_request(
            request=request,
            async_client=async_client,
            url=url,
            headers=headers,
            requested_query_params=requested_query_params,
        )

    # Generic case: forward the parsed JSON body as-is.
    return await async_client.request(
        method=request.method,
        url=url,
        headers=headers,
        params=requested_query_params,
        json=_parsed_body,
    )
@staticmethod
def is_multipart(request: Request) -> bool:
    """Return True when the incoming request declares a multipart/form-data body."""
    content_type = request.headers.get("content-type", "")
    return "multipart/form-data" in content_type
@staticmethod
async def _build_request_files_from_upload_file(
    upload_file: Union[UploadFile, StarletteUploadFile],
) -> Tuple[Optional[str], bytes, Optional[str]]:
    """Convert an UploadFile into the (filename, content, content_type) tuple httpx expects."""
    # read() drains the upload stream into memory so httpx can send it.
    file_bytes = await upload_file.read()
    name = upload_file.filename
    mime_type = upload_file.content_type
    return (name, file_bytes, mime_type)
@staticmethod
async def make_multipart_http_request(
    request: Request,
    async_client: httpx.AsyncClient,
    url: httpx.URL,
    headers: dict,
    requested_query_params: Optional[dict] = None,
) -> httpx.Response:
    """
    Forward a multipart/form-data request upstream.

    Splits the incoming form into file uploads (sent via `files=`) and plain
    form fields (sent via `data=`), then replays both against the target URL.
    """
    form = await request.form()

    files: dict = {}
    fields: dict = {}
    for name, value in form.items():
        is_file_upload = isinstance(value, (StarletteUploadFile, UploadFile))
        if is_file_upload:
            files[name] = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
                upload_file=value
            )
        else:
            fields[name] = value

    return await async_client.request(
        method=request.method,
        url=url,
        headers=headers,
        params=requested_query_params,
        files=files,
        data=fields,
    )
async def pass_through_request( # noqa: PLR0915
request: Request,
@ -424,7 +520,7 @@ async def pass_through_request( # noqa: PLR0915
start_time = datetime.now()
logging_obj = Logging(
model="unknown",
messages=[{"role": "user", "content": json.dumps(_parsed_body)}],
messages=[{"role": "user", "content": safe_dumps(_parsed_body)}],
stream=False,
call_type="pass_through_endpoint",
start_time=start_time,
@ -453,7 +549,6 @@ async def pass_through_request( # noqa: PLR0915
logging_obj.model_call_details["litellm_call_id"] = litellm_call_id
# combine url with query params for logging
requested_query_params: Optional[dict] = (
query_params or request.query_params.__dict__
)
@ -474,7 +569,7 @@ async def pass_through_request( # noqa: PLR0915
logging_url = str(url) + "?" + requested_query_params_str
logging_obj.pre_call(
input=[{"role": "user", "content": json.dumps(_parsed_body)}],
input=[{"role": "user", "content": safe_dumps(_parsed_body)}],
api_key="",
additional_args={
"complete_input_dict": _parsed_body,
@ -525,22 +620,16 @@ async def pass_through_request( # noqa: PLR0915
)
verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
if request.method == "GET":
response = await async_client.request(
method=request.method,
response = (
await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler(
request=request,
async_client=async_client,
url=url,
headers=headers,
params=requested_query_params,
requested_query_params=requested_query_params,
_parsed_body=_parsed_body,
)
else:
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
json=_parsed_body,
)
)
verbose_proxy_logger.debug("response.headers= %s", response.headers)
if _is_streaming_response(response) is True: