This commit is contained in:
Steve Farthing 2025-02-04 21:11:19 -05:00
parent fe0f9213af
commit 9724ee94df
3 changed files with 34 additions and 25 deletions

View file

@ -2175,7 +2175,7 @@ class SpecialHeaders(enum.Enum):
azure_authorization = "API-Key" azure_authorization = "API-Key"
anthropic_authorization = "x-api-key" anthropic_authorization = "x-api-key"
google_ai_studio_authorization = "x-goog-api-key" google_ai_studio_authorization = "x-goog-api-key"
bing_search_authorization = "Ocp-Apim-Subscription-Key" azure_apim_authorization = "Ocp-Apim-Subscription-Key"
class LitellmDataForBackendLLMCall(TypedDict, total=False): class LitellmDataForBackendLLMCall(TypedDict, total=False):

View file

@ -78,10 +78,10 @@ google_ai_studio_api_key_header = APIKeyHeader(
auto_error=False, auto_error=False,
description="If google ai studio client used.", description="If google ai studio client used.",
) )
bing_search_header = APIKeyHeader( azure_apim_header = APIKeyHeader(
name=SpecialHeaders.bing_search_authorization.value, name=SpecialHeaders.azure_apim_authorization.value,
auto_error=False, auto_error=False,
description="Custom header for Bing Search requests", description="The default name of the subscription key header of Azure",
) )
@ -456,7 +456,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
azure_api_key_header: str, azure_api_key_header: str,
anthropic_api_key_header: Optional[str], anthropic_api_key_header: Optional[str],
google_ai_studio_api_key_header: Optional[str], google_ai_studio_api_key_header: Optional[str],
bing_search_header: Optional[str], azure_apim_header: Optional[str],
request_data: dict, request_data: dict,
) -> UserAPIKeyAuth: ) -> UserAPIKeyAuth:
@ -500,8 +500,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
api_key = anthropic_api_key_header api_key = anthropic_api_key_header
elif isinstance(google_ai_studio_api_key_header, str): elif isinstance(google_ai_studio_api_key_header, str):
api_key = google_ai_studio_api_key_header api_key = google_ai_studio_api_key_header
elif isinstance(bing_search_header, str): elif isinstance(azure_apim_header, str):
api_key = bing_search_header api_key = azure_apim_header
elif pass_through_endpoints is not None: elif pass_through_endpoints is not None:
for endpoint in pass_through_endpoints: for endpoint in pass_through_endpoints:
if endpoint.get("path", "") == route: if endpoint.get("path", "") == route:
@ -1325,7 +1325,7 @@ async def user_api_key_auth(
google_ai_studio_api_key_header: Optional[str] = fastapi.Security( google_ai_studio_api_key_header: Optional[str] = fastapi.Security(
google_ai_studio_api_key_header google_ai_studio_api_key_header
), ),
bing_search_header: Optional[str] = fastapi.Security(bing_search_header), azure_apim_header: Optional[str] = fastapi.Security(azure_apim_header),
) -> UserAPIKeyAuth: ) -> UserAPIKeyAuth:
""" """
Parent function to authenticate user api key / jwt token. Parent function to authenticate user api key / jwt token.
@ -1339,7 +1339,7 @@ async def user_api_key_auth(
azure_api_key_header=azure_api_key_header, azure_api_key_header=azure_api_key_header,
anthropic_api_key_header=anthropic_api_key_header, anthropic_api_key_header=anthropic_api_key_header,
google_ai_studio_api_key_header=google_ai_studio_api_key_header, google_ai_studio_api_key_header=google_ai_studio_api_key_header,
bing_search_header=bing_search_header, azure_apim_header=azure_apim_header,
request_data=request_data, request_data=request_data,
) )

View file

@ -3,7 +3,7 @@ import asyncio
import json import json
from base64 import b64encode from base64 import b64encode
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional, Union, Dict
from urllib.parse import urlencode, parse_qs from urllib.parse import urlencode, parse_qs
import httpx import httpx
@ -296,6 +296,22 @@ def get_response_headers(
return return_headers return return_headers
def get_merged_query_parameters(
existing_url: httpx.URL, request_query_params: Dict[str, Union[str, list]]
) -> Dict[str, Union[str, List[str]]]:
# Get the existing query params from the target URL
existing_query_string = existing_url.query.decode("utf-8")
existing_query_params = parse_qs(existing_query_string)
# parse_qs returns a dict where each value is a list, so let's flatten it
existing_query_params = {
k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
}
# Merge the query params, giving priority to the existing ones
return {**request_query_params, **existing_query_params}
def get_endpoint_type(url: str) -> EndpointType: def get_endpoint_type(url: str) -> EndpointType:
if ("generateContent") in url or ("streamGenerateContent") in url: if ("generateContent") in url or ("streamGenerateContent") in url:
return EndpointType.VERTEX_AI return EndpointType.VERTEX_AI
@ -328,23 +344,16 @@ async def pass_through_request( # noqa: PLR0915
) )
if merge_query_params: if merge_query_params:
# Get the query params from the request
request_query_params = dict(request.query_params)
# Get the existing query params from the target URL
existing_query_string = url.query.decode("utf-8")
existing_query_params = parse_qs(existing_query_string)
# parse_qs returns a dict where each value is a list, so let's flatten it
existing_query_params = {
k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
}
# Merge the query params, giving priority to the existing ones
merged_query_params = {**request_query_params, **existing_query_params}
# Create a new URL with the merged query params # Create a new URL with the merged query params
url = url.copy_with(query=urlencode(merged_query_params).encode("ascii")) url = url.copy_with(
query=urlencode(
get_merged_query_parameters(
existing_url=url,
request_query_params=dict(request.query_params),
)
).encode("ascii")
)
endpoint_type: EndpointType = get_endpoint_type(str(url)) endpoint_type: EndpointType = get_endpoint_type(str(url))