Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
[Bug Fix] Add support for UploadFile on LLM Pass through endpoints (OpenAI, Azure etc) (#9853)
* http passthrough file handling
* fix make_multipart_http_request
* test_pass_through_file_operations
* unit tests for file handling
This commit is contained in: parent 6ba3c4a4f8, commit 08a3620414.
3 changed files with 241 additions and 19 deletions.
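
In practical terms, the change lets multipart/form-data bodies (file uploads) travel through LiteLLM pass-through routes instead of being re-serialized as JSON, which is how the pre-change code forwarded every non-GET body. Below is a minimal sketch of such an upload sent directly with httpx; the base URL and API key mirror the e2e test at the bottom of this diff, while the exact route ("/openai/files" here) is an assumption that depends on how the pass-through endpoint is mounted.

    import httpx

    # Base URL and key taken from the e2e test below; the "/openai/files" route is an
    # assumption, so adjust it to wherever your pass-through endpoint is mounted.
    resp = httpx.post(
        "http://0.0.0.0:4000/openai/files",
        headers={"Authorization": "Bearer sk-1234"},
        files={"file": ("notes.txt", b"hello world", "text/plain")},  # (filename, bytes, content type)
        data={"purpose": "assistants"},  # plain form field alongside the file part
    )
    print(resp.status_code, resp.text)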
File 1 of 3: litellm/proxy/pass_through_endpoints/pass_through_endpoints.py

@@ -4,16 +4,26 @@ import json
 import uuid
 from base64 import b64encode
 from datetime import datetime
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 from urllib.parse import parse_qs, urlencode, urlparse

 import httpx
-from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
+from fastapi import (
+    APIRouter,
+    Depends,
+    HTTPException,
+    Request,
+    Response,
+    UploadFile,
+    status,
+)
 from fastapi.responses import StreamingResponse
+from starlette.datastructures import UploadFile as StarletteUploadFile

 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
 from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.proxy._types import (
     ConfigFieldInfo,
@@ -358,6 +368,92 @@ class HttpPassThroughEndpointHelpers:
         )
         return response

+    @staticmethod
+    async def non_streaming_http_request_handler(
+        request: Request,
+        async_client: httpx.AsyncClient,
+        url: httpx.URL,
+        headers: dict,
+        requested_query_params: Optional[dict] = None,
+        _parsed_body: Optional[dict] = None,
+    ) -> httpx.Response:
+        """
+        Handle non-streaming HTTP requests.
+
+        Handles the special cases of GET requests, multipart/form-data requests, and generic httpx requests.
+        """
+        if request.method == "GET":
+            response = await async_client.request(
+                method=request.method,
+                url=url,
+                headers=headers,
+                params=requested_query_params,
+            )
+        elif HttpPassThroughEndpointHelpers.is_multipart(request) is True:
+            return await HttpPassThroughEndpointHelpers.make_multipart_http_request(
+                request=request,
+                async_client=async_client,
+                url=url,
+                headers=headers,
+                requested_query_params=requested_query_params,
+            )
+        else:
+            # Generic httpx request
+            response = await async_client.request(
+                method=request.method,
+                url=url,
+                headers=headers,
+                params=requested_query_params,
+                json=_parsed_body,
+            )
+        return response
+
+    @staticmethod
+    def is_multipart(request: Request) -> bool:
+        """Check if the request is a multipart/form-data request"""
+        return "multipart/form-data" in request.headers.get("content-type", "")
+
+    @staticmethod
+    async def _build_request_files_from_upload_file(
+        upload_file: Union[UploadFile, StarletteUploadFile],
+    ) -> Tuple[Optional[str], bytes, Optional[str]]:
+        """Build a (filename, content, content_type) tuple from an UploadFile object"""
+        file_content = await upload_file.read()
+        return (upload_file.filename, file_content, upload_file.content_type)
+
+    @staticmethod
+    async def make_multipart_http_request(
+        request: Request,
+        async_client: httpx.AsyncClient,
+        url: httpx.URL,
+        headers: dict,
+        requested_query_params: Optional[dict] = None,
+    ) -> httpx.Response:
+        """Process multipart/form-data requests, handling both files and form fields"""
+        form_data = await request.form()
+        files = {}
+        form_data_dict = {}
+
+        for field_name, field_value in form_data.items():
+            if isinstance(field_value, (StarletteUploadFile, UploadFile)):
+                files[field_name] = (
+                    await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
+                        upload_file=field_value
+                    )
+                )
+            else:
+                form_data_dict[field_name] = field_value
+
+        response = await async_client.request(
+            method=request.method,
+            url=url,
+            headers=headers,
+            params=requested_query_params,
+            files=files,
+            data=form_data_dict,
+        )
+        return response
+

 async def pass_through_request(  # noqa: PLR0915
     request: Request,
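
To make the forwarding shape concrete: for each file field the helper builds a (filename, raw bytes, content type) tuple, and every other form field goes into a plain dict, which is exactly the files=/data= split that httpx expects. Here is a standalone sketch of that hand-off, with illustrative field names and an arbitrary echo endpoint rather than anything taken from the diff.

    import asyncio

    import httpx

    async def forward_example() -> None:
        # File fields become (filename, raw bytes, content type) tuples ...
        files = {"file": ("notes.txt", b"hello from the proxy", "text/plain")}
        # ... and plain form fields stay in a separate dict.
        data = {"purpose": "assistants"}
        async with httpx.AsyncClient() as client:
            # httpx re-encodes these as a multipart/form-data body, the same way
            # async_client.request(..., files=files, data=data) does in the helper above.
            response = await client.post("https://httpbin.org/post", files=files, data=data)
            print(response.status_code)

    asyncio.run(forward_example())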
@@ -424,7 +520,7 @@ async def pass_through_request(  # noqa: PLR0915
     start_time = datetime.now()
     logging_obj = Logging(
         model="unknown",
-        messages=[{"role": "user", "content": json.dumps(_parsed_body)}],
+        messages=[{"role": "user", "content": safe_dumps(_parsed_body)}],
         stream=False,
         call_type="pass_through_endpoint",
         start_time=start_time,
@@ -453,7 +549,6 @@ async def pass_through_request(  # noqa: PLR0915
     logging_obj.model_call_details["litellm_call_id"] = litellm_call_id

     # combine url with query params for logging
-
     requested_query_params: Optional[dict] = (
         query_params or request.query_params.__dict__
     )
@@ -474,7 +569,7 @@ async def pass_through_request(  # noqa: PLR0915
         logging_url = str(url) + "?" + requested_query_params_str

     logging_obj.pre_call(
-        input=[{"role": "user", "content": json.dumps(_parsed_body)}],
+        input=[{"role": "user", "content": safe_dumps(_parsed_body)}],
         api_key="",
         additional_args={
             "complete_input_dict": _parsed_body,
@@ -525,22 +620,16 @@ async def pass_through_request(  # noqa: PLR0915
     )
     verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
-
-    if request.method == "GET":
-        response = await async_client.request(
-            method=request.method,
-            url=url,
-            headers=headers,
-            params=requested_query_params,
-        )
-    else:
-        response = await async_client.request(
-            method=request.method,
-            url=url,
-            headers=headers,
-            params=requested_query_params,
-            json=_parsed_body,
-        )
+    response = (
+        await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler(
+            request=request,
+            async_client=async_client,
+            url=url,
+            headers=headers,
+            requested_query_params=requested_query_params,
+            _parsed_body=_parsed_body,
+        )
+    )

     verbose_proxy_logger.debug("response.headers= %s", response.headers)

     if _is_streaming_response(response) is True:
File 2 of 3: new unit tests for HttpPassThroughEndpointHelpers (new file, 116 lines)

@@ -0,0 +1,116 @@
+import json
+import os
+import sys
+from io import BytesIO
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import pytest
+from fastapi import Request, UploadFile
+from fastapi.testclient import TestClient
+from starlette.datastructures import Headers
+from starlette.datastructures import UploadFile as StarletteUploadFile
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+
+from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
+    HttpPassThroughEndpointHelpers,
+)
+
+
+# Test is_multipart
+def test_is_multipart():
+    # Test with multipart content type
+    request = MagicMock(spec=Request)
+    request.headers = Headers({"content-type": "multipart/form-data; boundary=123"})
+    assert HttpPassThroughEndpointHelpers.is_multipart(request) is True
+
+    # Test with non-multipart content type
+    request.headers = Headers({"content-type": "application/json"})
+    assert HttpPassThroughEndpointHelpers.is_multipart(request) is False
+
+    # Test with no content type
+    request.headers = Headers({})
+    assert HttpPassThroughEndpointHelpers.is_multipart(request) is False
+
+
+# Test _build_request_files_from_upload_file
+@pytest.mark.asyncio
+async def test_build_request_files_from_upload_file():
+    # Test with FastAPI UploadFile
+    file_content = b"test content"
+    file = BytesIO(file_content)
+    # Create SpooledTemporaryFile with content type headers
+    headers = {"content-type": "text/plain"}
+    upload_file = UploadFile(file=file, filename="test.txt", headers=headers)
+    upload_file.read = AsyncMock(return_value=file_content)
+
+    result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
+        upload_file
+    )
+    assert result == ("test.txt", file_content, "text/plain")
+
+    # Test with Starlette UploadFile
+    file2 = BytesIO(file_content)
+    starlette_file = StarletteUploadFile(
+        file=file2,
+        filename="test2.txt",
+        headers=Headers({"content-type": "text/plain"}),
+    )
+    starlette_file.read = AsyncMock(return_value=file_content)
+
+    result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
+        starlette_file
+    )
+    assert result == ("test2.txt", file_content, "text/plain")
+
+
+# Test make_multipart_http_request
+@pytest.mark.asyncio
+async def test_make_multipart_http_request():
+    # Mock request with file and form field
+    request = MagicMock(spec=Request)
+    request.method = "POST"
+
+    # Mock form data
+    file_content = b"test file content"
+    file = BytesIO(file_content)
+    # Create SpooledTemporaryFile with content type headers
+    headers = {"content-type": "text/plain"}
+    upload_file = UploadFile(file=file, filename="test.txt", headers=headers)
+    upload_file.read = AsyncMock(return_value=file_content)
+
+    form_data = {"file": upload_file, "text_field": "test value"}
+    request.form = AsyncMock(return_value=form_data)
+
+    # Mock httpx client
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {}
+
+    async_client = MagicMock()
+    async_client.request = AsyncMock(return_value=mock_response)
+
+    # Test the function
+    response = await HttpPassThroughEndpointHelpers.make_multipart_http_request(
+        request=request,
+        async_client=async_client,
+        url=httpx.URL("http://test.com"),
+        headers={},
+        requested_query_params=None,
+    )
+
+    # Verify the response
+    assert response == mock_response
+
+    # Verify the client call
+    async_client.request.assert_called_once()
+    call_args = async_client.request.call_args[1]
+
+    assert call_args["method"] == "POST"
+    assert str(call_args["url"]) == "http://test.com"
+    assert isinstance(call_args["files"], dict)
+    assert isinstance(call_args["data"], dict)
+    assert call_args["data"]["text_field"] == "test value"
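
For completeness, a hedged alternative to the MagicMock-based header setup used in test_is_multipart above: building a real Request from a minimal ASGI scope, where headers are a list of (bytes, bytes) pairs. This variant is not part of the commit; it only assumes litellm is importable from the test environment.

    from fastapi import Request

    from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
        HttpPassThroughEndpointHelpers,
    )

    def test_is_multipart_with_real_request():
        # A minimal ASGI scope is enough for header access; no mocking required.
        scope = {
            "type": "http",
            "method": "POST",
            "headers": [(b"content-type", b"multipart/form-data; boundary=abc")],
        }
        request = Request(scope)
        assert HttpPassThroughEndpointHelpers.is_multipart(request) is True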
File 3 of 3: OpenAI Assistants pass-through e2e test

@@ -2,14 +2,31 @@ import pytest
 import openai
 import aiohttp
 import asyncio
+import tempfile
 from typing_extensions import override
 from openai import AssistantEventHandler


 client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234")

+
+def test_pass_through_file_operations():
+    # Create a temporary file
+    with tempfile.NamedTemporaryFile(mode='w+', suffix='.txt', delete=False) as temp_file:
+        temp_file.write("This is a test file for the OpenAI Assistants API.")
+        temp_file.flush()
+
+        # create a file
+        file = client.files.create(
+            file=open(temp_file.name, "rb"),
+            purpose="assistants",
+        )
+        print("file created", file)
+
+        # delete the file
+        delete_file = client.files.delete(file.id)
+        print("file deleted", delete_file)
+

 def test_openai_assistants_e2e_operations():
     assistant = client.beta.assistants.create(
         name="Math Tutor",
         instructions="You are a personal math tutor. Write and run code to answer math questions.",
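
One caveat worth flagging in the new e2e test: NamedTemporaryFile(delete=False) leaves the temporary file on disk after the test finishes, and the open(temp_file.name, "rb") handle is never closed explicitly. Below is a hedged variant with explicit cleanup, reusing the same client and endpoints the test already exercises; the assert on the delete response assumes the standard OpenAI SDK FileDeleted shape.

    import os
    import tempfile

    import openai

    client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234")

    def test_pass_through_file_operations_with_cleanup():
        # Write the fixture to a temp file that the test removes itself afterwards.
        with tempfile.NamedTemporaryFile(mode="w+", suffix=".txt", delete=False) as temp_file:
            temp_file.write("This is a test file for the OpenAI Assistants API.")
            temp_path = temp_file.name
        try:
            with open(temp_path, "rb") as fh:
                created = client.files.create(file=fh, purpose="assistants")
            deleted = client.files.delete(created.id)
            assert deleted.deleted is True
        finally:
            os.unlink(temp_path)  # remove the local temp file even if the API calls fail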