diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 5ab66e4fcf..536531496f 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2207,7 +2207,7 @@ class SpecialHeaders(enum.Enum): azure_authorization = "API-Key" anthropic_authorization = "x-api-key" google_ai_studio_authorization = "x-goog-api-key" - bing_search_authorization = "Ocp-Apim-Subscription-Key" + azure_apim_authorization = "Ocp-Apim-Subscription-Key" class LitellmDataForBackendLLMCall(TypedDict, total=False): diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index d3caa1194f..948e37be8a 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -75,10 +75,10 @@ google_ai_studio_api_key_header = APIKeyHeader( auto_error=False, description="If google ai studio client used.", ) -bing_search_header = APIKeyHeader( - name=SpecialHeaders.bing_search_authorization.value, +azure_apim_header = APIKeyHeader( + name=SpecialHeaders.azure_apim_authorization.value, auto_error=False, - description="Custom header for Bing Search requests", + description="The default name of the subscription key header of Azure", ) @@ -289,7 +289,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 azure_api_key_header: str, anthropic_api_key_header: Optional[str], google_ai_studio_api_key_header: Optional[str], - bing_search_header: Optional[str], + azure_apim_header: Optional[str], request_data: dict, ) -> UserAPIKeyAuth: @@ -333,8 +333,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 api_key = anthropic_api_key_header elif isinstance(google_ai_studio_api_key_header, str): api_key = google_ai_studio_api_key_header - elif isinstance(bing_search_header, str): - api_key = bing_search_header + elif isinstance(azure_apim_header, str): + api_key = azure_apim_header elif pass_through_endpoints is not None: for endpoint in pass_through_endpoints: if endpoint.get("path", "") == route: @@ -1160,7 +1160,7 @@ async def user_api_key_auth( google_ai_studio_api_key_header: Optional[str] = fastapi.Security( google_ai_studio_api_key_header ), - bing_search_header: Optional[str] = fastapi.Security(bing_search_header), + azure_apim_header: Optional[str] = fastapi.Security(azure_apim_header), ) -> UserAPIKeyAuth: """ Parent function to authenticate user api key / jwt token. @@ -1174,7 +1174,7 @@ async def user_api_key_auth( azure_api_key_header=azure_api_key_header, anthropic_api_key_header=anthropic_api_key_header, google_ai_studio_api_key_header=google_ai_studio_api_key_header, - bing_search_header=bing_search_header, + azure_apim_header=azure_apim_header, request_data=request_data, ) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index fcbdfc1fc6..e919cb1a60 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -3,7 +3,7 @@ import asyncio import json from base64 import b64encode from datetime import datetime -from typing import List, Optional +from typing import List, Optional, Union, Dict from urllib.parse import urlencode, parse_qs import httpx @@ -296,6 +296,22 @@ def get_response_headers( return return_headers +def get_merged_query_parameters( + existing_url: httpx.URL, request_query_params: Dict[str, Union[str, list]] +) -> Dict[str, Union[str, List[str]]]: + # Get the existing query params from the target URL + existing_query_string = existing_url.query.decode("utf-8") + existing_query_params = parse_qs(existing_query_string) + + # parse_qs returns a dict where each value is a list, so let's flatten it + existing_query_params = { + k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items() + } + + # Merge the query params, giving priority to the existing ones + return {**request_query_params, **existing_query_params} + + def get_endpoint_type(url: str) -> EndpointType: if ("generateContent") in url or ("streamGenerateContent") in url: return EndpointType.VERTEX_AI @@ -328,23 +344,16 @@ async def pass_through_request( # noqa: PLR0915 ) if merge_query_params: - # Get the query params from the request - request_query_params = dict(request.query_params) - - # Get the existing query params from the target URL - existing_query_string = url.query.decode("utf-8") - existing_query_params = parse_qs(existing_query_string) - - # parse_qs returns a dict where each value is a list, so let's flatten it - existing_query_params = { - k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items() - } - - # Merge the query params, giving priority to the existing ones - merged_query_params = {**request_query_params, **existing_query_params} # Create a new URL with the merged query params - url = url.copy_with(query=urlencode(merged_query_params).encode("ascii")) + url = url.copy_with( + query=urlencode( + get_merged_query_parameters( + existing_url=url, + request_query_params=dict(request.query_params), + ) + ).encode("ascii") + ) endpoint_type: EndpointType = get_endpoint_type(str(url))