mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
[Bug Fix] Add support for UploadFile on LLM Pass through endpoints (OpenAI, Azure etc) (#9853)
* http passthrough file handling
* fix make_multipart_http_request
* test_pass_through_file_operations
* unit tests for file handling
This commit is contained in:
parent
6ba3c4a4f8
commit
08a3620414
3 changed files with 241 additions and 19 deletions
|
@ -4,16 +4,26 @@ import json
|
|||
import uuid
|
||||
from base64 import b64encode
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Request,
|
||||
Response,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import StreamingResponse
|
||||
from starlette.datastructures import UploadFile as StarletteUploadFile
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.proxy._types import (
|
||||
ConfigFieldInfo,
|
||||
|
@ -358,6 +368,92 @@ class HttpPassThroughEndpointHelpers:
|
|||
)
|
||||
return response
|
||||
|
||||
@staticmethod
async def non_streaming_http_request_handler(
    request: Request,
    async_client: httpx.AsyncClient,
    url: httpx.URL,
    headers: dict,
    requested_query_params: Optional[dict] = None,
    _parsed_body: Optional[dict] = None,
) -> httpx.Response:
    """
    Send a non-streaming pass-through request and return the upstream response.

    Dispatch order:
    1. GET requests are sent with query params only (no body argument).
    2. multipart/form-data requests are delegated to
       `make_multipart_http_request`, which forwards files and form fields.
    3. Everything else is sent with `_parsed_body` as the JSON payload.

    Args:
        request: Incoming request; its method (and content-type header)
            selects the branch.
        async_client: httpx client used to issue the outgoing request.
        url: Target URL of the outgoing request.
        headers: Headers sent with the outgoing request.
        requested_query_params: Query params sent with the outgoing request.
        _parsed_body: Body forwarded as JSON in the generic branch only.

    Returns:
        The `httpx.Response` produced by the outgoing request.
    """
    # GET: forward method, url, headers and query params only.
    if request.method == "GET":
        return await async_client.request(
            method=request.method,
            url=url,
            headers=headers,
            params=requested_query_params,
        )

    # File uploads / form posts need multipart encoding rather than JSON.
    if HttpPassThroughEndpointHelpers.is_multipart(request) is True:
        return await HttpPassThroughEndpointHelpers.make_multipart_http_request(
            request=request,
            async_client=async_client,
            url=url,
            headers=headers,
            requested_query_params=requested_query_params,
        )

    # Generic httpx method: forward the parsed body as JSON.
    return await async_client.request(
        method=request.method,
        url=url,
        headers=headers,
        params=requested_query_params,
        json=_parsed_body,
    )
|
||||
|
||||
@staticmethod
def is_multipart(request: Request) -> bool:
    """
    Check if the request is a multipart/form-data request.

    The decision is based on the request's `content-type` header; a missing
    header is treated as "not multipart".

    Args:
        request: Incoming request whose headers are inspected.

    Returns:
        True when the content-type header contains `multipart/form-data`
        (compared case-insensitively), otherwise False.
    """
    content_type = request.headers.get("content-type", "")
    # Media type names are case-insensitive (RFC 9110), so normalise before
    # matching; otherwise e.g. "Multipart/Form-Data" would fall through to
    # the JSON branch of the caller.
    return "multipart/form-data" in content_type.lower()
|
||||
|
||||
@staticmethod
async def _build_request_files_from_upload_file(
    upload_file: Union[UploadFile, StarletteUploadFile],
) -> Tuple[Optional[str], bytes, Optional[str]]:
    """
    Convert an uploaded file into a `(filename, content, content_type)` tuple.

    The upload's content is obtained with `await upload_file.read()`, and
    the filename / content type are taken from the upload object as-is.

    Args:
        upload_file: The FastAPI / Starlette upload to convert.

    Returns:
        A 3-tuple of `(upload_file.filename, <bytes read>, upload_file.content_type)`.
    """
    payload = await upload_file.read()
    name, mime_type = upload_file.filename, upload_file.content_type
    return name, payload, mime_type
|
||||
|
||||
@staticmethod
async def make_multipart_http_request(
    request: Request,
    async_client: httpx.AsyncClient,
    url: httpx.URL,
    headers: dict,
    requested_query_params: Optional[dict] = None,
) -> httpx.Response:
    """
    Process multipart/form-data requests, handling both files and form fields.

    The incoming form is parsed with `await request.form()`. Upload fields
    are converted with `_build_request_files_from_upload_file` and sent via
    the `files` argument; all other fields are sent via `data`.

    Repeated field names (several files under one name, or list-style
    fields such as `name[]`) are preserved: every value is forwarded, not
    just one per name.

    Args:
        request: Incoming multipart request; supplies the method and form.
        async_client: httpx client used to issue the outgoing request.
        url: Target URL of the outgoing request.
        headers: Headers sent with the outgoing request.
        requested_query_params: Query params sent with the outgoing request.

    Returns:
        The `httpx.Response` produced by the outgoing request.
    """
    form_data = await request.form()

    # Starlette's multi-dict `.items()` exposes a single value per key, so
    # repeated fields would be dropped. `multi_items()` yields every
    # (name, value) pair; fall back to `.items()` for plain mappings.
    multi_items = getattr(form_data, "multi_items", None)
    field_pairs = multi_items() if callable(multi_items) else form_data.items()

    file_parts = []
    form_data_dict = {}

    for field_name, field_value in field_pairs:
        if isinstance(field_value, (StarletteUploadFile, UploadFile)):
            file_parts.append(
                (
                    field_name,
                    await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
                        upload_file=field_value
                    ),
                )
            )
        elif field_name in form_data_dict:
            # Repeated plain field: collect the values in a list, which
            # httpx encodes as one form part per item.
            previous = form_data_dict[field_name]
            if isinstance(previous, list):
                previous.append(field_value)
            else:
                form_data_dict[field_name] = [previous, field_value]
        else:
            form_data_dict[field_name] = field_value

    # Keep the original dict shape when every file field name is unique;
    # switch to httpx's list-of-tuples form only when a name repeats, since
    # a dict cannot hold two files under the same key.
    file_field_names = [name for name, _ in file_parts]
    if len(set(file_field_names)) == len(file_field_names):
        files = dict(file_parts)
    else:
        files = file_parts

    response = await async_client.request(
        method=request.method,
        url=url,
        headers=headers,
        params=requested_query_params,
        files=files,
        data=form_data_dict,
    )
    return response
|
||||
|
||||
|
||||
async def pass_through_request( # noqa: PLR0915
|
||||
request: Request,
|
||||
|
@ -424,7 +520,7 @@ async def pass_through_request( # noqa: PLR0915
|
|||
start_time = datetime.now()
|
||||
logging_obj = Logging(
|
||||
model="unknown",
|
||||
messages=[{"role": "user", "content": json.dumps(_parsed_body)}],
|
||||
messages=[{"role": "user", "content": safe_dumps(_parsed_body)}],
|
||||
stream=False,
|
||||
call_type="pass_through_endpoint",
|
||||
start_time=start_time,
|
||||
|
@ -453,7 +549,6 @@ async def pass_through_request( # noqa: PLR0915
|
|||
logging_obj.model_call_details["litellm_call_id"] = litellm_call_id
|
||||
|
||||
# combine url with query params for logging
|
||||
|
||||
requested_query_params: Optional[dict] = (
|
||||
query_params or request.query_params.__dict__
|
||||
)
|
||||
|
@ -474,7 +569,7 @@ async def pass_through_request( # noqa: PLR0915
|
|||
logging_url = str(url) + "?" + requested_query_params_str
|
||||
|
||||
logging_obj.pre_call(
|
||||
input=[{"role": "user", "content": json.dumps(_parsed_body)}],
|
||||
input=[{"role": "user", "content": safe_dumps(_parsed_body)}],
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": _parsed_body,
|
||||
|
@ -525,22 +620,16 @@ async def pass_through_request( # noqa: PLR0915
|
|||
)
|
||||
verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
|
||||
|
||||
if request.method == "GET":
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
response = (
|
||||
await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler(
|
||||
request=request,
|
||||
async_client=async_client,
|
||||
url=url,
|
||||
headers=headers,
|
||||
params=requested_query_params,
|
||||
requested_query_params=requested_query_params,
|
||||
_parsed_body=_parsed_body,
|
||||
)
|
||||
else:
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
params=requested_query_params,
|
||||
json=_parsed_body,
|
||||
)
|
||||
|
||||
)
|
||||
verbose_proxy_logger.debug("response.headers= %s", response.headers)
|
||||
|
||||
if _is_streaming_response(response) is True:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue