diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index a6b1b3e614..563d0cb543 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -4,16 +4,26 @@ import json import uuid from base64 import b64encode from datetime import datetime -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union from urllib.parse import parse_qs, urlencode, urlparse import httpx -from fastapi import APIRouter, Depends, HTTPException, Request, Response, status +from fastapi import ( + APIRouter, + Depends, + HTTPException, + Request, + Response, + UploadFile, + status, +) from fastapi.responses import StreamingResponse +from starlette.datastructures import UploadFile as StarletteUploadFile import litellm from litellm._logging import verbose_proxy_logger from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.proxy._types import ( ConfigFieldInfo, @@ -358,6 +368,92 @@ class HttpPassThroughEndpointHelpers: ) return response + @staticmethod + async def non_streaming_http_request_handler( + request: Request, + async_client: httpx.AsyncClient, + url: httpx.URL, + headers: dict, + requested_query_params: Optional[dict] = None, + _parsed_body: Optional[dict] = None, + ) -> httpx.Response: + """ + Handle non-streaming HTTP requests + + Handles special cases when GET requests, multipart/form-data requests, and generic httpx requests + """ + if request.method == "GET": + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + ) + elif HttpPassThroughEndpointHelpers.is_multipart(request) is True: + return await HttpPassThroughEndpointHelpers.make_multipart_http_request( + request=request, + async_client=async_client, + url=url, + headers=headers, + requested_query_params=requested_query_params, + ) + else: + # Generic httpx method + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + json=_parsed_body, + ) + return response + + @staticmethod + def is_multipart(request: Request) -> bool: + """Check if the request is a multipart/form-data request""" + return "multipart/form-data" in request.headers.get("content-type", "") + + @staticmethod + async def _build_request_files_from_upload_file( + upload_file: Union[UploadFile, StarletteUploadFile], + ) -> Tuple[Optional[str], bytes, Optional[str]]: + """Build a request files dict from an UploadFile object""" + file_content = await upload_file.read() + return (upload_file.filename, file_content, upload_file.content_type) + + @staticmethod + async def make_multipart_http_request( + request: Request, + async_client: httpx.AsyncClient, + url: httpx.URL, + headers: dict, + requested_query_params: Optional[dict] = None, + ) -> httpx.Response: + """Process multipart/form-data requests, handling both files and form fields""" + form_data = await request.form() + files = {} + form_data_dict = {} + + for field_name, field_value in form_data.items(): + if isinstance(field_value, (StarletteUploadFile, UploadFile)): + files[field_name] = ( + await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file( + upload_file=field_value + ) + ) + else: + form_data_dict[field_name] = field_value + + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + files=files, + data=form_data_dict, + ) + return response + async def pass_through_request( # noqa: PLR0915 request: Request, @@ -424,7 +520,7 @@ async def pass_through_request( # noqa: PLR0915 start_time = datetime.now() logging_obj = Logging( model="unknown", - messages=[{"role": "user", "content": json.dumps(_parsed_body)}], + messages=[{"role": "user", "content": safe_dumps(_parsed_body)}], stream=False, call_type="pass_through_endpoint", start_time=start_time, @@ -453,7 +549,6 @@ async def pass_through_request( # noqa: PLR0915 logging_obj.model_call_details["litellm_call_id"] = litellm_call_id # combine url with query params for logging - requested_query_params: Optional[dict] = ( query_params or request.query_params.__dict__ ) @@ -474,7 +569,7 @@ async def pass_through_request( # noqa: PLR0915 logging_url = str(url) + "?" + requested_query_params_str logging_obj.pre_call( - input=[{"role": "user", "content": json.dumps(_parsed_body)}], + input=[{"role": "user", "content": safe_dumps(_parsed_body)}], api_key="", additional_args={ "complete_input_dict": _parsed_body, @@ -525,22 +620,16 @@ async def pass_through_request( # noqa: PLR0915 ) verbose_proxy_logger.debug("request body: {}".format(_parsed_body)) - if request.method == "GET": - response = await async_client.request( - method=request.method, + response = ( + await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler( + request=request, + async_client=async_client, url=url, headers=headers, - params=requested_query_params, + requested_query_params=requested_query_params, + _parsed_body=_parsed_body, ) - else: - response = await async_client.request( - method=request.method, - url=url, - headers=headers, - params=requested_query_params, - json=_parsed_body, - ) - + ) verbose_proxy_logger.debug("response.headers= %s", response.headers) if _is_streaming_response(response) is True: diff --git a/tests/litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py b/tests/litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py new file mode 100644 index 0000000000..43d4dd9cd8 --- /dev/null +++ b/tests/litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py @@ -0,0 +1,116 @@ +import json +import os +import sys +from io import BytesIO +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from fastapi import Request, UploadFile +from fastapi.testclient import TestClient +from starlette.datastructures import Headers +from starlette.datastructures import UploadFile as StarletteUploadFile + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path + +from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + HttpPassThroughEndpointHelpers, +) + + +# Test is_multipart +def test_is_multipart(): + # Test with multipart content type + request = MagicMock(spec=Request) + request.headers = Headers({"content-type": "multipart/form-data; boundary=123"}) + assert HttpPassThroughEndpointHelpers.is_multipart(request) is True + + # Test with non-multipart content type + request.headers = Headers({"content-type": "application/json"}) + assert HttpPassThroughEndpointHelpers.is_multipart(request) is False + + # Test with no content type + request.headers = Headers({}) + assert HttpPassThroughEndpointHelpers.is_multipart(request) is False + + +# Test _build_request_files_from_upload_file +@pytest.mark.asyncio +async def test_build_request_files_from_upload_file(): + # Test with FastAPI UploadFile + file_content = b"test content" + file = BytesIO(file_content) + # Create SpooledTemporaryFile with content type headers + headers = {"content-type": "text/plain"} + upload_file = UploadFile(file=file, filename="test.txt", headers=headers) + upload_file.read = AsyncMock(return_value=file_content) + + result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file( + upload_file + ) + assert result == ("test.txt", file_content, "text/plain") + + # Test with Starlette UploadFile + file2 = BytesIO(file_content) + starlette_file = StarletteUploadFile( + file=file2, + filename="test2.txt", + headers=Headers({"content-type": "text/plain"}), + ) + starlette_file.read = AsyncMock(return_value=file_content) + + result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file( + starlette_file + ) + assert result == ("test2.txt", file_content, "text/plain") + + +# Test make_multipart_http_request +@pytest.mark.asyncio +async def test_make_multipart_http_request(): + # Mock request with file and form field + request = MagicMock(spec=Request) + request.method = "POST" + + # Mock form data + file_content = b"test file content" + file = BytesIO(file_content) + # Create SpooledTemporaryFile with content type headers + headers = {"content-type": "text/plain"} + upload_file = UploadFile(file=file, filename="test.txt", headers=headers) + upload_file.read = AsyncMock(return_value=file_content) + + form_data = {"file": upload_file, "text_field": "test value"} + request.form = AsyncMock(return_value=form_data) + + # Mock httpx client + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {} + + async_client = MagicMock() + async_client.request = AsyncMock(return_value=mock_response) + + # Test the function + response = await HttpPassThroughEndpointHelpers.make_multipart_http_request( + request=request, + async_client=async_client, + url=httpx.URL("http://test.com"), + headers={}, + requested_query_params=None, + ) + + # Verify the response + assert response == mock_response + + # Verify the client call + async_client.request.assert_called_once() + call_args = async_client.request.call_args[1] + + assert call_args["method"] == "POST" + assert str(call_args["url"]) == "http://test.com" + assert isinstance(call_args["files"], dict) + assert isinstance(call_args["data"], dict) + assert call_args["data"]["text_field"] == "test value" diff --git a/tests/pass_through_tests/test_openai_assistants_passthrough.py b/tests/pass_through_tests/test_openai_assistants_passthrough.py index 694d3c090e..40361ab39f 100644 --- a/tests/pass_through_tests/test_openai_assistants_passthrough.py +++ b/tests/pass_through_tests/test_openai_assistants_passthrough.py @@ -2,14 +2,31 @@ import pytest import openai import aiohttp import asyncio +import tempfile from typing_extensions import override from openai import AssistantEventHandler + client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234") +def test_pass_through_file_operations(): + # Create a temporary file + with tempfile.NamedTemporaryFile(mode='w+', suffix='.txt', delete=False) as temp_file: + temp_file.write("This is a test file for the OpenAI Assistants API.") + temp_file.flush() + + # create a file + file = client.files.create( + file=open(temp_file.name, "rb"), + purpose="assistants", + ) + print("file created", file) + + # delete the file + delete_file = client.files.delete(file.id) + print("file deleted", delete_file) def test_openai_assistants_e2e_operations(): - assistant = client.beta.assistants.create( name="Math Tutor", instructions="You are a personal math tutor. Write and run code to answer math questions.",