Merge branch 'main' into stevefarthing/bing-search-pass-thru

This commit is contained in:
Steve Farthing 2025-03-11 08:06:56 -04:00 committed by GitHub
commit b79b126597
741 changed files with 66437 additions and 15378 deletions

View file

@ -4,7 +4,7 @@ import json
from base64 import b64encode
from datetime import datetime
from typing import List, Optional, Union, Dict
from urllib.parse import urlencode, parse_qs
from urllib.parse import urlencode, parse_qs, urlparse
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
@ -26,6 +26,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.utils import StandardLoggingUserAPIKeyMetadata
from .streaming_handler import PassThroughStreamingHandler
from .success_handler import PassThroughEndpointLogging
@ -134,7 +135,7 @@ async def chat_completion_pass_through_endpoint( # noqa: PLR0915
data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or data["model"] # default passed in http request
or data.get("model", None) # default passed in http request
)
if user_model:
data["model"] = user_model
@ -259,66 +260,98 @@ async def chat_completion_pass_through_endpoint( # noqa: PLR0915
code=getattr(e, "status_code", 500),
)
class HttpPassThroughEndpointHelpers:
@staticmethod
def forward_headers_from_request(
request: Request,
headers: dict,
forward_headers: Optional[bool] = False,
):
"""
Helper to forward headers from original request
"""
if forward_headers is True:
request_headers = dict(request.headers)
def forward_headers_from_request(
request: Request,
headers: dict,
forward_headers: Optional[bool] = False,
):
"""
Helper to forward headers from original request
"""
if forward_headers is True:
request_headers = dict(request.headers)
# Header We Should NOT forward
request_headers.pop("content-length", None)
request_headers.pop("host", None)
# Header We Should NOT forward
request_headers.pop("content-length", None)
request_headers.pop("host", None)
# Combine request headers with custom headers
headers = {**request_headers, **headers}
return headers
# Combine request headers with custom headers
headers = {**request_headers, **headers}
return headers
@staticmethod
def get_response_headers(
headers: httpx.Headers, litellm_call_id: Optional[str] = None
) -> dict:
excluded_headers = {"transfer-encoding", "content-encoding"}
return_headers = {
key: value
for key, value in headers.items()
if key.lower() not in excluded_headers
}
if litellm_call_id:
return_headers["x-litellm-call-id"] = litellm_call_id
def get_response_headers(
headers: httpx.Headers, litellm_call_id: Optional[str] = None
) -> dict:
excluded_headers = {"transfer-encoding", "content-encoding"}
return return_headers
return_headers = {
key: value
for key, value in headers.items()
if key.lower() not in excluded_headers
}
if litellm_call_id:
return_headers["x-litellm-call-id"] = litellm_call_id
@staticmethod
def get_endpoint_type(url: str) -> EndpointType:
    """
    Classify a pass-through target URL by provider.

    Returns ``EndpointType.VERTEX_AI`` when the URL text contains one of the
    Vertex content-generation markers, ``EndpointType.ANTHROPIC`` when the
    parsed hostname is exactly ``api.anthropic.com``, and
    ``EndpointType.GENERIC`` otherwise.
    """
    # Parse up front so a malformed URL fails here regardless of which
    # branch would otherwise be taken.
    url_parts = urlparse(url)

    # Both spellings are listed because "streamGenerateContent" does not
    # contain the lower-case "generateContent" substring.
    vertex_markers = ("generateContent", "streamGenerateContent")
    if any(marker in url for marker in vertex_markers):
        return EndpointType.VERTEX_AI

    # Compare the hostname only, not the full URL text.
    if url_parts.hostname == "api.anthropic.com":
        return EndpointType.ANTHROPIC

    return EndpointType.GENERIC
return return_headers
@staticmethod
def get_merged_query_parameters(
existing_url: httpx.URL, request_query_params: Dict[str, Union[str, list]]
) -> Dict[str, Union[str, List[str]]]:
# Get the existing query params from the target URL
existing_query_string = existing_url.query.decode("utf-8")
existing_query_params = parse_qs(existing_query_string)
# parse_qs returns a dict where each value is a list, so let's flatten it
existing_query_params = {
k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
}
def get_merged_query_parameters(
existing_url: httpx.URL, request_query_params: Dict[str, Union[str, list]]
) -> Dict[str, Union[str, List[str]]]:
# Get the existing query params from the target URL
existing_query_string = existing_url.query.decode("utf-8")
existing_query_params = parse_qs(existing_query_string)
# Merge the query params, giving priority to the existing ones
return {**request_query_params, **existing_query_params}
# parse_qs returns a dict where each value is a list, so let's flatten it
existing_query_params = {
k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
}
# Merge the query params, giving priority to the existing ones
return {**request_query_params, **existing_query_params}
def get_endpoint_type(url: str) -> EndpointType:
    """
    Classify a pass-through target URL by provider.

    Returns ``EndpointType.VERTEX_AI`` when the URL text contains
    ``generateContent`` or ``streamGenerateContent``,
    ``EndpointType.ANTHROPIC`` when the URL's hostname is exactly
    ``api.anthropic.com``, and ``EndpointType.GENERIC`` otherwise.

    The Anthropic check compares the parsed hostname instead of searching
    the whole URL string, so a URL that merely mentions
    ``api.anthropic.com`` in its path or query string, or whose host only
    starts with that text, is no longer treated as an Anthropic endpoint.
    A URL without a scheme has no parsed hostname and is classified as
    generic.
    """
    # urlparse is imported at the top of this file from urllib.parse.
    parsed_url = urlparse(url)
    if "generateContent" in url or "streamGenerateContent" in url:
        return EndpointType.VERTEX_AI
    elif parsed_url.hostname == "api.anthropic.com":
        return EndpointType.ANTHROPIC
    return EndpointType.GENERIC
@staticmethod
async def _make_non_streaming_http_request(
    request: Request,
    async_client: httpx.AsyncClient,
    url: str,
    headers: dict,
    requested_query_params: Optional[dict] = None,
    custom_body: Optional[dict] = None,
) -> httpx.Response:
    """
    Send a single non-streaming request through ``async_client``.

    The outgoing call reuses the incoming request's HTTP method together
    with the given ``url``, ``headers`` and ``requested_query_params``.
    ``custom_body`` is attached as the JSON payload for every method except
    GET; a GET request is sent without a ``json`` argument.

    Returns the response object produced by ``async_client.request``.
    """
    # Arguments shared by every HTTP method.
    request_kwargs = {
        "method": request.method,
        "url": url,
        "headers": headers,
        "params": requested_query_params,
    }

    # Only non-GET requests carry a JSON body.
    if request.method != "GET":
        request_kwargs["json"] = custom_body

    return await async_client.request(**request_kwargs)
async def pass_through_request( # noqa: PLR0915
request: Request,
@ -339,23 +372,27 @@ async def pass_through_request( # noqa: PLR0915
url = httpx.URL(target)
headers = custom_headers
headers = forward_headers_from_request(
headers = HttpPassThroughEndpointHelpers.forward_headers_from_request(
request=request, headers=headers, forward_headers=forward_headers
)
if merge_query_params:
# Create a new URL with the merged query params
url = url.copy_with(
query=urlencode(
get_merged_query_parameters(
HttpPassThroughEndpointHelpers.get_merged_query_parameters(
existing_url=url,
request_query_params=dict(request.query_params),
)
).encode("ascii")
)
endpoint_type: EndpointType = get_endpoint_type(str(url))
endpoint_type: EndpointType = HttpPassThroughEndpointHelpers.get_endpoint_type(
str(url)
)
_parsed_body = None
if custom_body:
@ -472,7 +509,7 @@ async def pass_through_request( # noqa: PLR0915
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
),
headers=get_response_headers(
headers=HttpPassThroughEndpointHelpers.get_response_headers(
headers=response.headers,
litellm_call_id=litellm_call_id,
),
@ -487,13 +524,21 @@ async def pass_through_request( # noqa: PLR0915
)
verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
json=_parsed_body,
)
if request.method == "GET":
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
)
else:
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
json=_parsed_body,
)
verbose_proxy_logger.debug("response.headers= %s", response.headers)
@ -515,7 +560,7 @@ async def pass_through_request( # noqa: PLR0915
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
),
headers=get_response_headers(
headers=HttpPassThroughEndpointHelpers.get_response_headers(
headers=response.headers,
litellm_call_id=litellm_call_id,
),
@ -555,7 +600,7 @@ async def pass_through_request( # noqa: PLR0915
return Response(
content=content,
status_code=response.status_code,
headers=get_response_headers(
headers=HttpPassThroughEndpointHelpers.get_response_headers(
headers=response.headers,
litellm_call_id=litellm_call_id,
),
@ -592,12 +637,19 @@ def _init_kwargs_for_pass_through_endpoint(
) -> dict:
_parsed_body = _parsed_body or {}
_litellm_metadata: Optional[dict] = _parsed_body.pop("litellm_metadata", None)
_metadata = {
"user_api_key": user_api_key_dict.api_key,
"user_api_key_user_id": user_api_key_dict.user_id,
"user_api_key_team_id": user_api_key_dict.team_id,
"user_api_key_end_user_id": user_api_key_dict.end_user_id,
}
_metadata = dict(
StandardLoggingUserAPIKeyMetadata(
user_api_key_hash=user_api_key_dict.api_key,
user_api_key_alias=user_api_key_dict.key_alias,
user_api_key_user_email=user_api_key_dict.user_email,
user_api_key_user_id=user_api_key_dict.user_id,
user_api_key_team_id=user_api_key_dict.team_id,
user_api_key_org_id=user_api_key_dict.org_id,
user_api_key_team_alias=user_api_key_dict.team_alias,
user_api_key_end_user_id=user_api_key_dict.end_user_id,
)
)
_metadata["user_api_key"] = user_api_key_dict.api_key
if _litellm_metadata:
_metadata.update(_litellm_metadata)
@ -640,7 +692,7 @@ def create_pass_through_route(
# check if target is an adapter.py or a url
import uuid
from litellm.proxy.utils import get_instance_fn
from litellm.proxy.types_utils.utils import get_instance_fn
try:
if isinstance(target, CustomLogger):