Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)

[Bug Fix] Add support for UploadFile on LLM Pass through endpoints (OpenAI, Azure etc) (#9853)

* http passthrough file handling
* fix make_multipart_http_request
* test_pass_through_file_operations
* unit tests for file handling

parent 6ba3c4a4f8
commit 08a3620414
3 changed files with 241 additions and 19 deletions
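In plain terms, the fix makes the proxy's pass-through routes forward multipart/form-data bodies (file uploads) upstream instead of failing on them. A minimal client-side sketch of what this enables, mirroring the e2e test at the bottom of this diff (the localhost URL and sk-1234 key are the test's own assumptions, not a general recommendation):

    import openai

    # Assumes a locally running LiteLLM proxy with the OpenAI pass-through
    # route mounted, exactly as in the e2e test below.
    client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234")

    # A multipart file upload forwarded through the proxy; before this
    # commit, UploadFile fields were not handled by the pass-through path.
    file = client.files.create(file=open("example.txt", "rb"), purpose="assistants")
    client.files.delete(file.id)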
@@ -4,16 +4,26 @@ import json
 import uuid
 from base64 import b64encode
 from datetime import datetime
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 from urllib.parse import parse_qs, urlencode, urlparse
 
 import httpx
-from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
+from fastapi import (
+    APIRouter,
+    Depends,
+    HTTPException,
+    Request,
+    Response,
+    UploadFile,
+    status,
+)
 from fastapi.responses import StreamingResponse
+from starlette.datastructures import UploadFile as StarletteUploadFile
 
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
 from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.proxy._types import (
     ConfigFieldInfo,
@@ -358,6 +368,92 @@ class HttpPassThroughEndpointHelpers:
         )
         return response
 
+    @staticmethod
+    async def non_streaming_http_request_handler(
+        request: Request,
+        async_client: httpx.AsyncClient,
+        url: httpx.URL,
+        headers: dict,
+        requested_query_params: Optional[dict] = None,
+        _parsed_body: Optional[dict] = None,
+    ) -> httpx.Response:
+        """
+        Handle non-streaming HTTP requests.
+
+        Handles the special cases of GET and multipart/form-data requests;
+        everything else falls through to a generic httpx request.
+        """
+        if request.method == "GET":
+            response = await async_client.request(
+                method=request.method,
+                url=url,
+                headers=headers,
+                params=requested_query_params,
+            )
+        elif HttpPassThroughEndpointHelpers.is_multipart(request) is True:
+            return await HttpPassThroughEndpointHelpers.make_multipart_http_request(
+                request=request,
+                async_client=async_client,
+                url=url,
+                headers=headers,
+                requested_query_params=requested_query_params,
+            )
+        else:
+            # Generic httpx method
+            response = await async_client.request(
+                method=request.method,
+                url=url,
+                headers=headers,
+                params=requested_query_params,
+                json=_parsed_body,
+            )
+        return response
+
+    @staticmethod
+    def is_multipart(request: Request) -> bool:
+        """Check if the request is a multipart/form-data request"""
+        return "multipart/form-data" in request.headers.get("content-type", "")
+
+    @staticmethod
+    async def _build_request_files_from_upload_file(
+        upload_file: Union[UploadFile, StarletteUploadFile],
+    ) -> Tuple[Optional[str], bytes, Optional[str]]:
+        """Build an httpx file tuple (filename, content, content_type) from an UploadFile"""
+        file_content = await upload_file.read()
+        return (upload_file.filename, file_content, upload_file.content_type)
+
+    @staticmethod
+    async def make_multipart_http_request(
+        request: Request,
+        async_client: httpx.AsyncClient,
+        url: httpx.URL,
+        headers: dict,
+        requested_query_params: Optional[dict] = None,
+    ) -> httpx.Response:
+        """Process multipart/form-data requests, handling both files and form fields"""
+        form_data = await request.form()
+        files = {}
+        form_data_dict = {}
+
+        for field_name, field_value in form_data.items():
+            if isinstance(field_value, (StarletteUploadFile, UploadFile)):
+                files[field_name] = (
+                    await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
+                        upload_file=field_value
+                    )
+                )
+            else:
+                form_data_dict[field_name] = field_value
+
+        response = await async_client.request(
+            method=request.method,
+            url=url,
+            headers=headers,
+            params=requested_query_params,
+            files=files,
+            data=form_data_dict,
+        )
+        return response
+
 
 async def pass_through_request(  # noqa: PLR0915
     request: Request,
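The helper above returns a (filename, content, content_type) triple because that is one of the file-entry shapes httpx accepts in its files= parameter; non-file form fields travel separately in data=. A standalone sketch of the shape make_multipart_http_request hands to async_client.request (the URL and field names here are illustrative only):

    import httpx

    # Same triple _build_request_files_from_upload_file produces:
    files = {"file": ("test.txt", b"test content", "text/plain")}
    # Plain form fields go in `data`; httpx merges `files` and `data`
    # into a single multipart/form-data body.
    data = {"purpose": "assistants"}

    response = httpx.post("http://localhost:8080/upload", files=files, data=data)
    print(response.status_code)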
@@ -424,7 +520,7 @@ async def pass_through_request( # noqa: PLR0915
     start_time = datetime.now()
     logging_obj = Logging(
         model="unknown",
-        messages=[{"role": "user", "content": json.dumps(_parsed_body)}],
+        messages=[{"role": "user", "content": safe_dumps(_parsed_body)}],
         stream=False,
         call_type="pass_through_endpoint",
         start_time=start_time,
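The json.dumps → safe_dumps swap here (and again in pre_call below) is tied to this feature: a parsed multipart body can carry values that are not JSON-serializable, and logging should not crash the request over it. A minimal stdlib-only illustration of the failure mode (litellm's safe_dumps fallback behavior is not shown in this diff, so only the json.dumps side is demonstrated):

    import json

    # A body like the ones pass-through now sees: bytes are not JSON-serializable.
    body = {"purpose": "assistants", "file": b"raw upload bytes"}
    try:
        json.dumps(body)
    except TypeError as err:
        print(err)  # Object of type bytes is not JSON serializable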
@@ -453,7 +549,6 @@ async def pass_through_request( # noqa: PLR0915
     logging_obj.model_call_details["litellm_call_id"] = litellm_call_id
 
     # combine url with query params for logging
-
     requested_query_params: Optional[dict] = (
         query_params or request.query_params.__dict__
     )
@@ -474,7 +569,7 @@ async def pass_through_request( # noqa: PLR0915
         logging_url = str(url) + "?" + requested_query_params_str
 
     logging_obj.pre_call(
-        input=[{"role": "user", "content": json.dumps(_parsed_body)}],
+        input=[{"role": "user", "content": safe_dumps(_parsed_body)}],
         api_key="",
         additional_args={
             "complete_input_dict": _parsed_body,
@@ -525,22 +620,16 @@ async def pass_through_request( # noqa: PLR0915
         )
     verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
 
-    if request.method == "GET":
-        response = await async_client.request(
-            method=request.method,
-            url=url,
-            headers=headers,
-            params=requested_query_params,
-        )
-    else:
-        response = await async_client.request(
-            method=request.method,
-            url=url,
-            headers=headers,
-            params=requested_query_params,
-            json=_parsed_body,
-        )
+    response = (
+        await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler(
+            request=request,
+            async_client=async_client,
+            url=url,
+            headers=headers,
+            requested_query_params=requested_query_params,
+            _parsed_body=_parsed_body,
+        )
+    )
 
     verbose_proxy_logger.debug("response.headers= %s", response.headers)
 
     if _is_streaming_response(response) is True:
@@ -0,0 +1,116 @@
+import json
+import os
+import sys
+from io import BytesIO
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import pytest
+from fastapi import Request, UploadFile
+from fastapi.testclient import TestClient
+from starlette.datastructures import Headers
+from starlette.datastructures import UploadFile as StarletteUploadFile
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+
+from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
+    HttpPassThroughEndpointHelpers,
+)
+
+
+# Test is_multipart
+def test_is_multipart():
+    # Test with multipart content type
+    request = MagicMock(spec=Request)
+    request.headers = Headers({"content-type": "multipart/form-data; boundary=123"})
+    assert HttpPassThroughEndpointHelpers.is_multipart(request) is True
+
+    # Test with non-multipart content type
+    request.headers = Headers({"content-type": "application/json"})
+    assert HttpPassThroughEndpointHelpers.is_multipart(request) is False
+
+    # Test with no content type
+    request.headers = Headers({})
+    assert HttpPassThroughEndpointHelpers.is_multipart(request) is False
+
+
+# Test _build_request_files_from_upload_file
+@pytest.mark.asyncio
+async def test_build_request_files_from_upload_file():
+    # Test with FastAPI UploadFile
+    file_content = b"test content"
+    file = BytesIO(file_content)
+    # Create an UploadFile with a content-type header
+    headers = {"content-type": "text/plain"}
+    upload_file = UploadFile(file=file, filename="test.txt", headers=headers)
+    upload_file.read = AsyncMock(return_value=file_content)
+
+    result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
+        upload_file
+    )
+    assert result == ("test.txt", file_content, "text/plain")
+
+    # Test with Starlette UploadFile
+    file2 = BytesIO(file_content)
+    starlette_file = StarletteUploadFile(
+        file=file2,
+        filename="test2.txt",
+        headers=Headers({"content-type": "text/plain"}),
+    )
+    starlette_file.read = AsyncMock(return_value=file_content)
+
+    result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
+        starlette_file
+    )
+    assert result == ("test2.txt", file_content, "text/plain")
+
+
+# Test make_multipart_http_request
+@pytest.mark.asyncio
+async def test_make_multipart_http_request():
+    # Mock request with file and form field
+    request = MagicMock(spec=Request)
+    request.method = "POST"
+
+    # Mock form data
+    file_content = b"test file content"
+    file = BytesIO(file_content)
+    # Create an UploadFile with a content-type header
+    headers = {"content-type": "text/plain"}
+    upload_file = UploadFile(file=file, filename="test.txt", headers=headers)
+    upload_file.read = AsyncMock(return_value=file_content)
+
+    form_data = {"file": upload_file, "text_field": "test value"}
+    request.form = AsyncMock(return_value=form_data)
+
+    # Mock httpx client
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {}
+
+    async_client = MagicMock()
+    async_client.request = AsyncMock(return_value=mock_response)
+
+    # Test the function
+    response = await HttpPassThroughEndpointHelpers.make_multipart_http_request(
+        request=request,
+        async_client=async_client,
+        url=httpx.URL("http://test.com"),
+        headers={},
+        requested_query_params=None,
+    )
+
+    # Verify the response
+    assert response == mock_response
+
+    # Verify the client call
+    async_client.request.assert_called_once()
+    call_args = async_client.request.call_args[1]
+
+    assert call_args["method"] == "POST"
+    assert str(call_args["url"]) == "http://test.com"
+    assert isinstance(call_args["files"], dict)
+    assert isinstance(call_args["data"], dict)
+    assert call_args["data"]["text_field"] == "test value"
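A side note on why these tests pass a content-type header when constructing UploadFile: both FastAPI's and Starlette's UploadFile derive content_type from the supplied headers, which is where the third element of the returned triple comes from. A self-contained check, using only names the test file already imports:

    import asyncio
    from io import BytesIO

    from starlette.datastructures import Headers
    from starlette.datastructures import UploadFile as StarletteUploadFile


    async def main() -> None:
        upload = StarletteUploadFile(
            file=BytesIO(b"hello"),
            filename="hello.txt",
            headers=Headers({"content-type": "text/plain"}),
        )
        # The same (filename, content, content_type) triple the proxy forwards.
        print(upload.filename, await upload.read(), upload.content_type)


    asyncio.run(main())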
@@ -2,14 +2,31 @@ import pytest
 import openai
 import aiohttp
 import asyncio
+import tempfile
 from typing_extensions import override
 from openai import AssistantEventHandler
 
 
 client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234")
 
+def test_pass_through_file_operations():
+    # Create a temporary file
+    with tempfile.NamedTemporaryFile(mode='w+', suffix='.txt', delete=False) as temp_file:
+        temp_file.write("This is a test file for the OpenAI Assistants API.")
+        temp_file.flush()
+
+        # create a file
+        file = client.files.create(
+            file=open(temp_file.name, "rb"),
+            purpose="assistants",
+        )
+        print("file created", file)
+
+        # delete the file
+        delete_file = client.files.delete(file.id)
+        print("file deleted", delete_file)
+
+
 def test_openai_assistants_e2e_operations():
-
     assistant = client.beta.assistants.create(
         name="Math Tutor",
         instructions="You are a personal math tutor. Write and run code to answer math questions.",
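One small caveat in test_pass_through_file_operations: NamedTemporaryFile(delete=False) leaves the temp file on disk after the test runs. A sketch of the cleanup the test could add (the os.unlink placement is an editorial suggestion, not part of this commit):

    import os
    import tempfile

    with tempfile.NamedTemporaryFile(mode='w+', suffix='.txt', delete=False) as temp_file:
        temp_file.write("This is a test file for the OpenAI Assistants API.")
        temp_file.flush()
        # ... upload and delete via the client, as in the test above ...

    os.unlink(temp_file.name)  # remove the leftover temp file once done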