mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
(Bug fix) - allow using Assistants GET, DELETE on /openai
pass through routes (#8818)
Changes in this commit: add `test_openai_assistants_e2e_operations` and OpenAI Assistants pass-through tests; fix GET request handling in the pass-through handler; add `_make_non_streaming_http_request` and `_is_assistants_api_request`; add `openai_proxy_route`; document the OpenAI pass-through routes; and resolve code-scanning alert no. 2240 (incomplete URL substring sanitization). Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
This commit is contained in:
parent
142276b468
commit
11fd5094c7
8 changed files with 572 additions and 84 deletions
|
@ -6,6 +6,7 @@ from datetime import datetime
|
|||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
from urllib.parse import urlparse
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
@ -259,48 +260,82 @@ async def chat_completion_pass_through_endpoint( # noqa: PLR0915
|
|||
)
|
||||
|
||||
|
||||
class HttpPassThroughEndpointHelpers:
    """Static helpers used by the HTTP pass-through endpoints.

    Groups the previously module-level helper functions
    (`forward_headers_from_request`, `get_response_headers`,
    `get_endpoint_type`) together with `_make_non_streaming_http_request`
    under one namespace.
    """

    @staticmethod
    def forward_headers_from_request(
        request: Request,
        headers: dict,
        forward_headers: Optional[bool] = False,
    ):
        """
        Helper to forward headers from original request.

        Args:
            request: Incoming FastAPI request whose headers may be forwarded.
            headers: Custom headers configured for the pass-through endpoint.
            forward_headers: When True, merge the original request headers
                into `headers` (custom headers win on key conflicts).

        Returns:
            dict: The headers to send upstream.
        """
        if forward_headers is True:
            request_headers = dict(request.headers)

            # Headers we should NOT forward — they describe the *original*
            # request/connection, not the upstream one we are about to make.
            request_headers.pop("content-length", None)
            request_headers.pop("host", None)

            # Combine request headers with custom headers; custom headers
            # take precedence since they come later in the merge.
            headers = {**request_headers, **headers}
        return headers

    @staticmethod
    def get_response_headers(
        headers: httpx.Headers, litellm_call_id: Optional[str] = None
    ) -> dict:
        """
        Build the response headers to return to the client.

        Drops hop-by-hop/encoding headers that no longer apply once the
        body has been read and re-emitted, and tags the response with the
        litellm call id when available.

        Args:
            headers: Upstream response headers.
            litellm_call_id: If set, added as `x-litellm-call-id`.

        Returns:
            dict: Filtered headers safe to return to the caller.
        """
        # transfer-encoding / content-encoding describe the upstream wire
        # format, which we do not preserve when proxying the body.
        excluded_headers = {"transfer-encoding", "content-encoding"}

        return_headers = {
            key: value
            for key, value in headers.items()
            if key.lower() not in excluded_headers
        }
        if litellm_call_id:
            return_headers["x-litellm-call-id"] = litellm_call_id

        return return_headers

    @staticmethod
    def get_endpoint_type(url: str) -> EndpointType:
        """
        Classify the target URL so provider-specific handling can be applied.

        Uses `urlparse(...).hostname` for the Anthropic check instead of a
        substring test — `"api.anthropic.com" in url` would also match
        attacker-controlled hosts like `api.anthropic.com.evil.example`
        (CodeQL: incomplete URL substring sanitization).

        Args:
            url: Full target URL of the pass-through request.

        Returns:
            EndpointType: VERTEX_AI, ANTHROPIC, or GENERIC.
        """
        parsed_url = urlparse(url)
        if ("generateContent") in url or ("streamGenerateContent") in url:
            return EndpointType.VERTEX_AI
        elif parsed_url.hostname == "api.anthropic.com":
            return EndpointType.ANTHROPIC
        return EndpointType.GENERIC

    @staticmethod
    async def _make_non_streaming_http_request(
        request: Request,
        async_client: httpx.AsyncClient,
        url: str,
        headers: dict,
        requested_query_params: Optional[dict] = None,
        custom_body: Optional[dict] = None,
    ) -> httpx.Response:
        """
        Make a non-streaming HTTP request.

        If request is GET, don't include a JSON body — some upstream APIs
        (e.g. OpenAI Assistants GET/DELETE routes) reject GET requests that
        carry a body.

        Args:
            request: Incoming request; its HTTP method is mirrored upstream.
            async_client: Shared httpx client used to issue the request.
            url: Upstream target URL.
            headers: Headers to send upstream.
            requested_query_params: Optional query params to forward.
            custom_body: Optional JSON body (ignored for GET).

        Returns:
            httpx.Response: The upstream response.
        """
        if request.method == "GET":
            response = await async_client.request(
                method=request.method,
                url=url,
                headers=headers,
                params=requested_query_params,
            )
        else:
            response = await async_client.request(
                method=request.method,
                url=url,
                headers=headers,
                params=requested_query_params,
                json=custom_body,
            )
        return response
||||
async def pass_through_request( # noqa: PLR0915
|
||||
|
@ -321,11 +356,13 @@ async def pass_through_request( # noqa: PLR0915
|
|||
|
||||
url = httpx.URL(target)
|
||||
headers = custom_headers
|
||||
headers = forward_headers_from_request(
|
||||
headers = HttpPassThroughEndpointHelpers.forward_headers_from_request(
|
||||
request=request, headers=headers, forward_headers=forward_headers
|
||||
)
|
||||
|
||||
endpoint_type: EndpointType = get_endpoint_type(str(url))
|
||||
endpoint_type: EndpointType = HttpPassThroughEndpointHelpers.get_endpoint_type(
|
||||
str(url)
|
||||
)
|
||||
|
||||
_parsed_body = None
|
||||
if custom_body:
|
||||
|
@ -442,7 +479,7 @@ async def pass_through_request( # noqa: PLR0915
|
|||
passthrough_success_handler_obj=pass_through_endpoint_logging,
|
||||
url_route=str(url),
|
||||
),
|
||||
headers=get_response_headers(
|
||||
headers=HttpPassThroughEndpointHelpers.get_response_headers(
|
||||
headers=response.headers,
|
||||
litellm_call_id=litellm_call_id,
|
||||
),
|
||||
|
@ -457,13 +494,21 @@ async def pass_through_request( # noqa: PLR0915
|
|||
)
|
||||
verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
|
||||
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
params=requested_query_params,
|
||||
json=_parsed_body,
|
||||
)
|
||||
if request.method == "GET":
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
params=requested_query_params,
|
||||
)
|
||||
else:
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
params=requested_query_params,
|
||||
json=_parsed_body,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("response.headers= %s", response.headers)
|
||||
|
||||
|
@ -485,7 +530,7 @@ async def pass_through_request( # noqa: PLR0915
|
|||
passthrough_success_handler_obj=pass_through_endpoint_logging,
|
||||
url_route=str(url),
|
||||
),
|
||||
headers=get_response_headers(
|
||||
headers=HttpPassThroughEndpointHelpers.get_response_headers(
|
||||
headers=response.headers,
|
||||
litellm_call_id=litellm_call_id,
|
||||
),
|
||||
|
@ -525,7 +570,7 @@ async def pass_through_request( # noqa: PLR0915
|
|||
return Response(
|
||||
content=content,
|
||||
status_code=response.status_code,
|
||||
headers=get_response_headers(
|
||||
headers=HttpPassThroughEndpointHelpers.get_response_headers(
|
||||
headers=response.headers,
|
||||
litellm_call_id=litellm_call_id,
|
||||
),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue