mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Bing Search Pass Thru
This commit is contained in:
parent
f66029470f
commit
75b713974f
4 changed files with 110 additions and 1 deletions
|
@ -4,6 +4,7 @@ import json
|
|||
from base64 import b64encode
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urlencode, parse_qs
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
|
@ -310,6 +311,7 @@ async def pass_through_request( # noqa: PLR0915
|
|||
user_api_key_dict: UserAPIKeyAuth,
|
||||
custom_body: Optional[dict] = None,
|
||||
forward_headers: Optional[bool] = False,
|
||||
merge_query_params: Optional[bool] = False,
|
||||
query_params: Optional[dict] = None,
|
||||
stream: Optional[bool] = None,
|
||||
):
|
||||
|
@ -325,6 +327,25 @@ async def pass_through_request( # noqa: PLR0915
|
|||
request=request, headers=headers, forward_headers=forward_headers
|
||||
)
|
||||
|
||||
if merge_query_params:
|
||||
# Get the query params from the request
|
||||
request_query_params = dict(request.query_params)
|
||||
|
||||
# Get the existing query params from the target URL
|
||||
existing_query_string = url.query.decode("utf-8")
|
||||
existing_query_params = parse_qs(existing_query_string)
|
||||
|
||||
# parse_qs returns a dict where each value is a list, so let's flatten it
|
||||
existing_query_params = {
|
||||
k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
|
||||
}
|
||||
|
||||
# Merge the query params, giving priority to the existing ones
|
||||
merged_query_params = {**request_query_params, **existing_query_params}
|
||||
|
||||
# Create a new URL with the merged query params
|
||||
url = url.copy_with(query=urlencode(merged_query_params).encode("ascii"))
|
||||
|
||||
endpoint_type: EndpointType = get_endpoint_type(str(url))
|
||||
|
||||
_parsed_body = None
|
||||
|
@ -604,6 +625,7 @@ def create_pass_through_route(
|
|||
target: str,
|
||||
custom_headers: Optional[dict] = None,
|
||||
_forward_headers: Optional[bool] = False,
|
||||
_merge_query_params: Optional[bool] = False,
|
||||
dependencies: Optional[List] = None,
|
||||
):
|
||||
# check if target is an adapter.py or a url
|
||||
|
@ -650,6 +672,7 @@ def create_pass_through_route(
|
|||
custom_headers=custom_headers or {},
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
forward_headers=_forward_headers,
|
||||
merge_query_params=_merge_query_params,
|
||||
query_params=query_params,
|
||||
stream=stream,
|
||||
custom_body=custom_body,
|
||||
|
@ -679,6 +702,7 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
|
|||
custom_headers=_custom_headers
|
||||
)
|
||||
_forward_headers = endpoint.get("forward_headers", None)
|
||||
_merge_query_params = endpoint.get("merge_query_params", None)
|
||||
_auth = endpoint.get("auth", None)
|
||||
_dependencies = None
|
||||
if _auth is not None and str(_auth).lower() == "true":
|
||||
|
@ -700,7 +724,12 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
|
|||
app.add_api_route( # type: ignore
|
||||
path=_path,
|
||||
endpoint=create_pass_through_route( # type: ignore
|
||||
_path, _target, _custom_headers, _forward_headers, _dependencies
|
||||
_path,
|
||||
_target,
|
||||
_custom_headers,
|
||||
_forward_headers,
|
||||
_merge_query_params,
|
||||
_dependencies,
|
||||
),
|
||||
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||
dependencies=_dependencies,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue