mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Merge branch 'main' into stevefarthing/bing-search-pass-thru
This commit is contained in:
commit
b79b126597
741 changed files with 66437 additions and 15378 deletions
|
@ -4,7 +4,7 @@ import json
|
|||
from base64 import b64encode
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Union, Dict
|
||||
from urllib.parse import urlencode, parse_qs
|
||||
from urllib.parse import urlencode, parse_qs, urlparse
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
|
@ -26,6 +26,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
|||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
from litellm.types.utils import StandardLoggingUserAPIKeyMetadata
|
||||
|
||||
from .streaming_handler import PassThroughStreamingHandler
|
||||
from .success_handler import PassThroughEndpointLogging
|
||||
|
@ -134,7 +135,7 @@ async def chat_completion_pass_through_endpoint( # noqa: PLR0915
|
|||
data["model"] = (
|
||||
general_settings.get("completion_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
or data["model"] # default passed in http request
|
||||
or data.get("model", None) # default passed in http request
|
||||
)
|
||||
if user_model:
|
||||
data["model"] = user_model
|
||||
|
@ -259,66 +260,98 @@ async def chat_completion_pass_through_endpoint( # noqa: PLR0915
|
|||
code=getattr(e, "status_code", 500),
|
||||
)
|
||||
|
||||
class HttpPassThroughEndpointHelpers:
|
||||
@staticmethod
|
||||
def forward_headers_from_request(
|
||||
request: Request,
|
||||
headers: dict,
|
||||
forward_headers: Optional[bool] = False,
|
||||
):
|
||||
"""
|
||||
Helper to forward headers from original request
|
||||
"""
|
||||
if forward_headers is True:
|
||||
request_headers = dict(request.headers)
|
||||
|
||||
def forward_headers_from_request(
|
||||
request: Request,
|
||||
headers: dict,
|
||||
forward_headers: Optional[bool] = False,
|
||||
):
|
||||
"""
|
||||
Helper to forward headers from original request
|
||||
"""
|
||||
if forward_headers is True:
|
||||
request_headers = dict(request.headers)
|
||||
# Header We Should NOT forward
|
||||
request_headers.pop("content-length", None)
|
||||
request_headers.pop("host", None)
|
||||
|
||||
# Header We Should NOT forward
|
||||
request_headers.pop("content-length", None)
|
||||
request_headers.pop("host", None)
|
||||
# Combine request headers with custom headers
|
||||
headers = {**request_headers, **headers}
|
||||
return headers
|
||||
|
||||
# Combine request headers with custom headers
|
||||
headers = {**request_headers, **headers}
|
||||
return headers
|
||||
@staticmethod
|
||||
def get_response_headers(
|
||||
headers: httpx.Headers, litellm_call_id: Optional[str] = None
|
||||
) -> dict:
|
||||
excluded_headers = {"transfer-encoding", "content-encoding"}
|
||||
|
||||
return_headers = {
|
||||
key: value
|
||||
for key, value in headers.items()
|
||||
if key.lower() not in excluded_headers
|
||||
}
|
||||
if litellm_call_id:
|
||||
return_headers["x-litellm-call-id"] = litellm_call_id
|
||||
|
||||
def get_response_headers(
|
||||
headers: httpx.Headers, litellm_call_id: Optional[str] = None
|
||||
) -> dict:
|
||||
excluded_headers = {"transfer-encoding", "content-encoding"}
|
||||
return return_headers
|
||||
|
||||
return_headers = {
|
||||
key: value
|
||||
for key, value in headers.items()
|
||||
if key.lower() not in excluded_headers
|
||||
}
|
||||
if litellm_call_id:
|
||||
return_headers["x-litellm-call-id"] = litellm_call_id
|
||||
@staticmethod
|
||||
def get_endpoint_type(url: str) -> EndpointType:
|
||||
parsed_url = urlparse(url)
|
||||
if ("generateContent") in url or ("streamGenerateContent") in url:
|
||||
return EndpointType.VERTEX_AI
|
||||
elif parsed_url.hostname == "api.anthropic.com":
|
||||
return EndpointType.ANTHROPIC
|
||||
return EndpointType.GENERIC
|
||||
|
||||
return return_headers
|
||||
@staticmethod
|
||||
def get_merged_query_parameters(
|
||||
existing_url: httpx.URL, request_query_params: Dict[str, Union[str, list]]
|
||||
) -> Dict[str, Union[str, List[str]]]:
|
||||
# Get the existing query params from the target URL
|
||||
existing_query_string = existing_url.query.decode("utf-8")
|
||||
existing_query_params = parse_qs(existing_query_string)
|
||||
|
||||
# parse_qs returns a dict where each value is a list, so let's flatten it
|
||||
existing_query_params = {
|
||||
k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
|
||||
}
|
||||
|
||||
def get_merged_query_parameters(
|
||||
existing_url: httpx.URL, request_query_params: Dict[str, Union[str, list]]
|
||||
) -> Dict[str, Union[str, List[str]]]:
|
||||
# Get the existing query params from the target URL
|
||||
existing_query_string = existing_url.query.decode("utf-8")
|
||||
existing_query_params = parse_qs(existing_query_string)
|
||||
# Merge the query params, giving priority to the existing ones
|
||||
return {**request_query_params, **existing_query_params}
|
||||
|
||||
# parse_qs returns a dict where each value is a list, so let's flatten it
|
||||
existing_query_params = {
|
||||
k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
|
||||
}
|
||||
|
||||
# Merge the query params, giving priority to the existing ones
|
||||
return {**request_query_params, **existing_query_params}
|
||||
|
||||
|
||||
def get_endpoint_type(url: str) -> EndpointType:
|
||||
if ("generateContent") in url or ("streamGenerateContent") in url:
|
||||
return EndpointType.VERTEX_AI
|
||||
elif ("api.anthropic.com") in url:
|
||||
return EndpointType.ANTHROPIC
|
||||
return EndpointType.GENERIC
|
||||
@staticmethod
|
||||
async def _make_non_streaming_http_request(
|
||||
request: Request,
|
||||
async_client: httpx.AsyncClient,
|
||||
url: str,
|
||||
headers: dict,
|
||||
requested_query_params: Optional[dict] = None,
|
||||
custom_body: Optional[dict] = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Make a non-streaming HTTP request
|
||||
|
||||
If request is GET, don't include a JSON body
|
||||
"""
|
||||
if request.method == "GET":
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
params=requested_query_params,
|
||||
)
|
||||
else:
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
params=requested_query_params,
|
||||
json=custom_body,
|
||||
)
|
||||
return response
|
||||
|
||||
async def pass_through_request( # noqa: PLR0915
|
||||
request: Request,
|
||||
|
@ -339,23 +372,27 @@ async def pass_through_request( # noqa: PLR0915
|
|||
|
||||
url = httpx.URL(target)
|
||||
headers = custom_headers
|
||||
headers = forward_headers_from_request(
|
||||
headers = HttpPassThroughEndpointHelpers.forward_headers_from_request(
|
||||
request=request, headers=headers, forward_headers=forward_headers
|
||||
)
|
||||
|
||||
|
||||
if merge_query_params:
|
||||
|
||||
# Create a new URL with the merged query params
|
||||
url = url.copy_with(
|
||||
query=urlencode(
|
||||
get_merged_query_parameters(
|
||||
HttpPassThroughEndpointHelpers.get_merged_query_parameters(
|
||||
existing_url=url,
|
||||
request_query_params=dict(request.query_params),
|
||||
)
|
||||
).encode("ascii")
|
||||
)
|
||||
|
||||
endpoint_type: EndpointType = get_endpoint_type(str(url))
|
||||
|
||||
endpoint_type: EndpointType = HttpPassThroughEndpointHelpers.get_endpoint_type(
|
||||
str(url)
|
||||
)
|
||||
|
||||
_parsed_body = None
|
||||
if custom_body:
|
||||
|
@ -472,7 +509,7 @@ async def pass_through_request( # noqa: PLR0915
|
|||
passthrough_success_handler_obj=pass_through_endpoint_logging,
|
||||
url_route=str(url),
|
||||
),
|
||||
headers=get_response_headers(
|
||||
headers=HttpPassThroughEndpointHelpers.get_response_headers(
|
||||
headers=response.headers,
|
||||
litellm_call_id=litellm_call_id,
|
||||
),
|
||||
|
@ -487,13 +524,21 @@ async def pass_through_request( # noqa: PLR0915
|
|||
)
|
||||
verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
|
||||
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
params=requested_query_params,
|
||||
json=_parsed_body,
|
||||
)
|
||||
if request.method == "GET":
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
params=requested_query_params,
|
||||
)
|
||||
else:
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
params=requested_query_params,
|
||||
json=_parsed_body,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("response.headers= %s", response.headers)
|
||||
|
||||
|
@ -515,7 +560,7 @@ async def pass_through_request( # noqa: PLR0915
|
|||
passthrough_success_handler_obj=pass_through_endpoint_logging,
|
||||
url_route=str(url),
|
||||
),
|
||||
headers=get_response_headers(
|
||||
headers=HttpPassThroughEndpointHelpers.get_response_headers(
|
||||
headers=response.headers,
|
||||
litellm_call_id=litellm_call_id,
|
||||
),
|
||||
|
@ -555,7 +600,7 @@ async def pass_through_request( # noqa: PLR0915
|
|||
return Response(
|
||||
content=content,
|
||||
status_code=response.status_code,
|
||||
headers=get_response_headers(
|
||||
headers=HttpPassThroughEndpointHelpers.get_response_headers(
|
||||
headers=response.headers,
|
||||
litellm_call_id=litellm_call_id,
|
||||
),
|
||||
|
@ -592,12 +637,19 @@ def _init_kwargs_for_pass_through_endpoint(
|
|||
) -> dict:
|
||||
_parsed_body = _parsed_body or {}
|
||||
_litellm_metadata: Optional[dict] = _parsed_body.pop("litellm_metadata", None)
|
||||
_metadata = {
|
||||
"user_api_key": user_api_key_dict.api_key,
|
||||
"user_api_key_user_id": user_api_key_dict.user_id,
|
||||
"user_api_key_team_id": user_api_key_dict.team_id,
|
||||
"user_api_key_end_user_id": user_api_key_dict.end_user_id,
|
||||
}
|
||||
_metadata = dict(
|
||||
StandardLoggingUserAPIKeyMetadata(
|
||||
user_api_key_hash=user_api_key_dict.api_key,
|
||||
user_api_key_alias=user_api_key_dict.key_alias,
|
||||
user_api_key_user_email=user_api_key_dict.user_email,
|
||||
user_api_key_user_id=user_api_key_dict.user_id,
|
||||
user_api_key_team_id=user_api_key_dict.team_id,
|
||||
user_api_key_org_id=user_api_key_dict.org_id,
|
||||
user_api_key_team_alias=user_api_key_dict.team_alias,
|
||||
user_api_key_end_user_id=user_api_key_dict.end_user_id,
|
||||
)
|
||||
)
|
||||
_metadata["user_api_key"] = user_api_key_dict.api_key
|
||||
if _litellm_metadata:
|
||||
_metadata.update(_litellm_metadata)
|
||||
|
||||
|
@ -640,7 +692,7 @@ def create_pass_through_route(
|
|||
# check if target is an adapter.py or a url
|
||||
import uuid
|
||||
|
||||
from litellm.proxy.utils import get_instance_fn
|
||||
from litellm.proxy.types_utils.utils import get_instance_fn
|
||||
|
||||
try:
|
||||
if isinstance(target, CustomLogger):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue