This commit is contained in:
Steve Farthing 2025-02-04 21:11:19 -05:00
parent 75b713974f
commit e8a859720b
3 changed files with 34 additions and 25 deletions

View file

@ -3,7 +3,7 @@ import asyncio
import json
from base64 import b64encode
from datetime import datetime
from typing import List, Optional
from typing import List, Optional, Union, Dict
from urllib.parse import urlencode, parse_qs
import httpx
@ -296,6 +296,22 @@ def get_response_headers(
return return_headers
def get_merged_query_parameters(
existing_url: httpx.URL, request_query_params: Dict[str, Union[str, list]]
) -> Dict[str, Union[str, List[str]]]:
# Get the existing query params from the target URL
existing_query_string = existing_url.query.decode("utf-8")
existing_query_params = parse_qs(existing_query_string)
# parse_qs returns a dict where each value is a list, so let's flatten it
existing_query_params = {
k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
}
# Merge the query params, giving priority to the existing ones
return {**request_query_params, **existing_query_params}
def get_endpoint_type(url: str) -> EndpointType:
if ("generateContent") in url or ("streamGenerateContent") in url:
return EndpointType.VERTEX_AI
@ -328,23 +344,16 @@ async def pass_through_request( # noqa: PLR0915
)
if merge_query_params:
# Get the query params from the request
request_query_params = dict(request.query_params)
# Get the existing query params from the target URL
existing_query_string = url.query.decode("utf-8")
existing_query_params = parse_qs(existing_query_string)
# parse_qs returns a dict where each value is a list, so let's flatten it
existing_query_params = {
k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
}
# Merge the query params, giving priority to the existing ones
merged_query_params = {**request_query_params, **existing_query_params}
# Create a new URL with the merged query params
url = url.copy_with(query=urlencode(merged_query_params).encode("ascii"))
url = url.copy_with(
query=urlencode(
get_merged_query_parameters(
existing_url=url,
request_query_params=dict(request.query_params),
)
).encode("ascii")
)
endpoint_type: EndpointType = get_endpoint_type(str(url))