(Bug fix) - allow using Assistants GET, DELETE on /openai pass through routes (#8818)

* test_openai_assistants_e2e_operations

* test openai assistants pass through

* fix GET request on pass through handler

* _make_non_streaming_http_request

* _is_assistants_api_request

* test_openai_assistants_e2e_operations

* test_openai_assistants_e2e_operations

* openai_proxy_route

* docs openai pass through

* docs openai pass through

* docs openai pass through

* test pass through handler

* Potential fix for code scanning alert no. 2240: Incomplete URL substring sanitization

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
This commit is contained in:
Ishaan Jaff 2025-02-25 19:19:00 -08:00 committed by GitHub
parent 142276b468
commit 11fd5094c7
8 changed files with 572 additions and 84 deletions

View file

@ -6,6 +6,7 @@ from datetime import datetime
from typing import List, Optional
import httpx
from urllib.parse import urlparse
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.responses import StreamingResponse
@ -259,48 +260,82 @@ async def chat_completion_pass_through_endpoint( # noqa: PLR0915
)
def forward_headers_from_request(
request: Request,
headers: dict,
forward_headers: Optional[bool] = False,
):
"""
Helper to forward headers from original request
"""
if forward_headers is True:
request_headers = dict(request.headers)
class HttpPassThroughEndpointHelpers:
@staticmethod
def forward_headers_from_request(
request: Request,
headers: dict,
forward_headers: Optional[bool] = False,
):
"""
Helper to forward headers from original request
"""
if forward_headers is True:
request_headers = dict(request.headers)
# Header We Should NOT forward
request_headers.pop("content-length", None)
request_headers.pop("host", None)
# Header We Should NOT forward
request_headers.pop("content-length", None)
request_headers.pop("host", None)
# Combine request headers with custom headers
headers = {**request_headers, **headers}
return headers
# Combine request headers with custom headers
headers = {**request_headers, **headers}
return headers
@staticmethod
def get_response_headers(
headers: httpx.Headers, litellm_call_id: Optional[str] = None
) -> dict:
excluded_headers = {"transfer-encoding", "content-encoding"}
def get_response_headers(
headers: httpx.Headers, litellm_call_id: Optional[str] = None
) -> dict:
excluded_headers = {"transfer-encoding", "content-encoding"}
return_headers = {
key: value
for key, value in headers.items()
if key.lower() not in excluded_headers
}
if litellm_call_id:
return_headers["x-litellm-call-id"] = litellm_call_id
return_headers = {
key: value
for key, value in headers.items()
if key.lower() not in excluded_headers
}
if litellm_call_id:
return_headers["x-litellm-call-id"] = litellm_call_id
return return_headers
return return_headers
@staticmethod
def get_endpoint_type(url: str) -> EndpointType:
parsed_url = urlparse(url)
if ("generateContent") in url or ("streamGenerateContent") in url:
return EndpointType.VERTEX_AI
elif parsed_url.hostname == "api.anthropic.com":
return EndpointType.ANTHROPIC
return EndpointType.GENERIC
@staticmethod
async def _make_non_streaming_http_request(
request: Request,
async_client: httpx.AsyncClient,
url: str,
headers: dict,
requested_query_params: Optional[dict] = None,
custom_body: Optional[dict] = None,
) -> httpx.Response:
"""
Make a non-streaming HTTP request
def get_endpoint_type(url: str) -> EndpointType:
if ("generateContent") in url or ("streamGenerateContent") in url:
return EndpointType.VERTEX_AI
elif ("api.anthropic.com") in url:
return EndpointType.ANTHROPIC
return EndpointType.GENERIC
If request is GET, don't include a JSON body
"""
if request.method == "GET":
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
)
else:
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
json=custom_body,
)
return response
async def pass_through_request( # noqa: PLR0915
@ -321,11 +356,13 @@ async def pass_through_request( # noqa: PLR0915
url = httpx.URL(target)
headers = custom_headers
headers = forward_headers_from_request(
headers = HttpPassThroughEndpointHelpers.forward_headers_from_request(
request=request, headers=headers, forward_headers=forward_headers
)
endpoint_type: EndpointType = get_endpoint_type(str(url))
endpoint_type: EndpointType = HttpPassThroughEndpointHelpers.get_endpoint_type(
str(url)
)
_parsed_body = None
if custom_body:
@ -442,7 +479,7 @@ async def pass_through_request( # noqa: PLR0915
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
),
headers=get_response_headers(
headers=HttpPassThroughEndpointHelpers.get_response_headers(
headers=response.headers,
litellm_call_id=litellm_call_id,
),
@ -457,13 +494,21 @@ async def pass_through_request( # noqa: PLR0915
)
verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
json=_parsed_body,
)
if request.method == "GET":
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
)
else:
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
json=_parsed_body,
)
verbose_proxy_logger.debug("response.headers= %s", response.headers)
@ -485,7 +530,7 @@ async def pass_through_request( # noqa: PLR0915
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
),
headers=get_response_headers(
headers=HttpPassThroughEndpointHelpers.get_response_headers(
headers=response.headers,
litellm_call_id=litellm_call_id,
),
@ -525,7 +570,7 @@ async def pass_through_request( # noqa: PLR0915
return Response(
content=content,
status_code=response.status_code,
headers=get_response_headers(
headers=HttpPassThroughEndpointHelpers.get_response_headers(
headers=response.headers,
litellm_call_id=litellm_call_id,
),