mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Feedback
This commit is contained in:
parent
fe0f9213af
commit
9724ee94df
3 changed files with 34 additions and 25 deletions
|
@ -2175,7 +2175,7 @@ class SpecialHeaders(enum.Enum):
|
|||
azure_authorization = "API-Key"
|
||||
anthropic_authorization = "x-api-key"
|
||||
google_ai_studio_authorization = "x-goog-api-key"
|
||||
bing_search_authorization = "Ocp-Apim-Subscription-Key"
|
||||
azure_apim_authorization = "Ocp-Apim-Subscription-Key"
|
||||
|
||||
|
||||
class LitellmDataForBackendLLMCall(TypedDict, total=False):
|
||||
|
|
|
@ -78,10 +78,10 @@ google_ai_studio_api_key_header = APIKeyHeader(
|
|||
auto_error=False,
|
||||
description="If google ai studio client used.",
|
||||
)
|
||||
bing_search_header = APIKeyHeader(
|
||||
name=SpecialHeaders.bing_search_authorization.value,
|
||||
azure_apim_header = APIKeyHeader(
|
||||
name=SpecialHeaders.azure_apim_authorization.value,
|
||||
auto_error=False,
|
||||
description="Custom header for Bing Search requests",
|
||||
description="The default name of the subscription key header of Azure",
|
||||
)
|
||||
|
||||
|
||||
|
@ -456,7 +456,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
azure_api_key_header: str,
|
||||
anthropic_api_key_header: Optional[str],
|
||||
google_ai_studio_api_key_header: Optional[str],
|
||||
bing_search_header: Optional[str],
|
||||
azure_apim_header: Optional[str],
|
||||
request_data: dict,
|
||||
) -> UserAPIKeyAuth:
|
||||
|
||||
|
@ -500,8 +500,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
|||
api_key = anthropic_api_key_header
|
||||
elif isinstance(google_ai_studio_api_key_header, str):
|
||||
api_key = google_ai_studio_api_key_header
|
||||
elif isinstance(bing_search_header, str):
|
||||
api_key = bing_search_header
|
||||
elif isinstance(azure_apim_header, str):
|
||||
api_key = azure_apim_header
|
||||
elif pass_through_endpoints is not None:
|
||||
for endpoint in pass_through_endpoints:
|
||||
if endpoint.get("path", "") == route:
|
||||
|
@ -1325,7 +1325,7 @@ async def user_api_key_auth(
|
|||
google_ai_studio_api_key_header: Optional[str] = fastapi.Security(
|
||||
google_ai_studio_api_key_header
|
||||
),
|
||||
bing_search_header: Optional[str] = fastapi.Security(bing_search_header),
|
||||
azure_apim_header: Optional[str] = fastapi.Security(azure_apim_header),
|
||||
) -> UserAPIKeyAuth:
|
||||
"""
|
||||
Parent function to authenticate user api key / jwt token.
|
||||
|
@ -1339,7 +1339,7 @@ async def user_api_key_auth(
|
|||
azure_api_key_header=azure_api_key_header,
|
||||
anthropic_api_key_header=anthropic_api_key_header,
|
||||
google_ai_studio_api_key_header=google_ai_studio_api_key_header,
|
||||
bing_search_header=bing_search_header,
|
||||
azure_apim_header=azure_apim_header,
|
||||
request_data=request_data,
|
||||
)
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import asyncio
|
|||
import json
|
||||
from base64 import b64encode
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union, Dict
|
||||
from urllib.parse import urlencode, parse_qs
|
||||
|
||||
import httpx
|
||||
|
@ -296,6 +296,22 @@ def get_response_headers(
|
|||
return return_headers
|
||||
|
||||
|
||||
def get_merged_query_parameters(
|
||||
existing_url: httpx.URL, request_query_params: Dict[str, Union[str, list]]
|
||||
) -> Dict[str, Union[str, List[str]]]:
|
||||
# Get the existing query params from the target URL
|
||||
existing_query_string = existing_url.query.decode("utf-8")
|
||||
existing_query_params = parse_qs(existing_query_string)
|
||||
|
||||
# parse_qs returns a dict where each value is a list, so let's flatten it
|
||||
existing_query_params = {
|
||||
k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
|
||||
}
|
||||
|
||||
# Merge the query params, giving priority to the existing ones
|
||||
return {**request_query_params, **existing_query_params}
|
||||
|
||||
|
||||
def get_endpoint_type(url: str) -> EndpointType:
|
||||
if ("generateContent") in url or ("streamGenerateContent") in url:
|
||||
return EndpointType.VERTEX_AI
|
||||
|
@ -328,23 +344,16 @@ async def pass_through_request( # noqa: PLR0915
|
|||
)
|
||||
|
||||
if merge_query_params:
|
||||
# Get the query params from the request
|
||||
request_query_params = dict(request.query_params)
|
||||
|
||||
# Get the existing query params from the target URL
|
||||
existing_query_string = url.query.decode("utf-8")
|
||||
existing_query_params = parse_qs(existing_query_string)
|
||||
|
||||
# parse_qs returns a dict where each value is a list, so let's flatten it
|
||||
existing_query_params = {
|
||||
k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
|
||||
}
|
||||
|
||||
# Merge the query params, giving priority to the existing ones
|
||||
merged_query_params = {**request_query_params, **existing_query_params}
|
||||
|
||||
# Create a new URL with the merged query params
|
||||
url = url.copy_with(query=urlencode(merged_query_params).encode("ascii"))
|
||||
url = url.copy_with(
|
||||
query=urlencode(
|
||||
get_merged_query_parameters(
|
||||
existing_url=url,
|
||||
request_query_params=dict(request.query_params),
|
||||
)
|
||||
).encode("ascii")
|
||||
)
|
||||
|
||||
endpoint_type: EndpointType = get_endpoint_type(str(url))
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue