Bing Search Pass Thru

This commit is contained in:
Steve Farthing 2025-01-27 08:58:04 -05:00
parent f66029470f
commit 75b713974f
4 changed files with 110 additions and 1 deletions

View file

@ -2207,6 +2207,7 @@ class SpecialHeaders(enum.Enum):
azure_authorization = "API-Key" azure_authorization = "API-Key"
anthropic_authorization = "x-api-key" anthropic_authorization = "x-api-key"
google_ai_studio_authorization = "x-goog-api-key" google_ai_studio_authorization = "x-goog-api-key"
bing_search_authorization = "Ocp-Apim-Subscription-Key"
class LitellmDataForBackendLLMCall(TypedDict, total=False): class LitellmDataForBackendLLMCall(TypedDict, total=False):

View file

@ -75,6 +75,11 @@ google_ai_studio_api_key_header = APIKeyHeader(
auto_error=False, auto_error=False,
description="If google ai studio client used.", description="If google ai studio client used.",
) )
bing_search_header = APIKeyHeader(
name=SpecialHeaders.bing_search_authorization.value,
auto_error=False,
description="Custom header for Bing Search requests",
)
def _get_bearer_token( def _get_bearer_token(
@ -284,6 +289,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
azure_api_key_header: str, azure_api_key_header: str,
anthropic_api_key_header: Optional[str], anthropic_api_key_header: Optional[str],
google_ai_studio_api_key_header: Optional[str], google_ai_studio_api_key_header: Optional[str],
bing_search_header: Optional[str],
request_data: dict, request_data: dict,
) -> UserAPIKeyAuth: ) -> UserAPIKeyAuth:
@ -327,6 +333,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
api_key = anthropic_api_key_header api_key = anthropic_api_key_header
elif isinstance(google_ai_studio_api_key_header, str): elif isinstance(google_ai_studio_api_key_header, str):
api_key = google_ai_studio_api_key_header api_key = google_ai_studio_api_key_header
elif isinstance(bing_search_header, str):
api_key = bing_search_header
elif pass_through_endpoints is not None: elif pass_through_endpoints is not None:
for endpoint in pass_through_endpoints: for endpoint in pass_through_endpoints:
if endpoint.get("path", "") == route: if endpoint.get("path", "") == route:
@ -1152,6 +1160,7 @@ async def user_api_key_auth(
google_ai_studio_api_key_header: Optional[str] = fastapi.Security( google_ai_studio_api_key_header: Optional[str] = fastapi.Security(
google_ai_studio_api_key_header google_ai_studio_api_key_header
), ),
bing_search_header: Optional[str] = fastapi.Security(bing_search_header),
) -> UserAPIKeyAuth: ) -> UserAPIKeyAuth:
""" """
Parent function to authenticate user api key / jwt token. Parent function to authenticate user api key / jwt token.
@ -1165,6 +1174,7 @@ async def user_api_key_auth(
azure_api_key_header=azure_api_key_header, azure_api_key_header=azure_api_key_header,
anthropic_api_key_header=anthropic_api_key_header, anthropic_api_key_header=anthropic_api_key_header,
google_ai_studio_api_key_header=google_ai_studio_api_key_header, google_ai_studio_api_key_header=google_ai_studio_api_key_header,
bing_search_header=bing_search_header,
request_data=request_data, request_data=request_data,
) )

View file

@ -4,6 +4,7 @@ import json
from base64 import b64encode from base64 import b64encode
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional
from urllib.parse import urlencode, parse_qs
import httpx import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
@ -310,6 +311,7 @@ async def pass_through_request( # noqa: PLR0915
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
custom_body: Optional[dict] = None, custom_body: Optional[dict] = None,
forward_headers: Optional[bool] = False, forward_headers: Optional[bool] = False,
merge_query_params: Optional[bool] = False,
query_params: Optional[dict] = None, query_params: Optional[dict] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
): ):
@ -325,6 +327,25 @@ async def pass_through_request( # noqa: PLR0915
request=request, headers=headers, forward_headers=forward_headers request=request, headers=headers, forward_headers=forward_headers
) )
if merge_query_params:
# Get the query params from the request
request_query_params = dict(request.query_params)
# Get the existing query params from the target URL
existing_query_string = url.query.decode("utf-8")
existing_query_params = parse_qs(existing_query_string)
# parse_qs returns a dict where each value is a list, so let's flatten it
existing_query_params = {
k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
}
# Merge the query params, giving priority to the existing ones
merged_query_params = {**request_query_params, **existing_query_params}
# Create a new URL with the merged query params
url = url.copy_with(query=urlencode(merged_query_params).encode("ascii"))
endpoint_type: EndpointType = get_endpoint_type(str(url)) endpoint_type: EndpointType = get_endpoint_type(str(url))
_parsed_body = None _parsed_body = None
@ -604,6 +625,7 @@ def create_pass_through_route(
target: str, target: str,
custom_headers: Optional[dict] = None, custom_headers: Optional[dict] = None,
_forward_headers: Optional[bool] = False, _forward_headers: Optional[bool] = False,
_merge_query_params: Optional[bool] = False,
dependencies: Optional[List] = None, dependencies: Optional[List] = None,
): ):
# check if target is an adapter.py or a url # check if target is an adapter.py or a url
@ -650,6 +672,7 @@ def create_pass_through_route(
custom_headers=custom_headers or {}, custom_headers=custom_headers or {},
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
forward_headers=_forward_headers, forward_headers=_forward_headers,
merge_query_params=_merge_query_params,
query_params=query_params, query_params=query_params,
stream=stream, stream=stream,
custom_body=custom_body, custom_body=custom_body,
@ -679,6 +702,7 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
custom_headers=_custom_headers custom_headers=_custom_headers
) )
_forward_headers = endpoint.get("forward_headers", None) _forward_headers = endpoint.get("forward_headers", None)
_merge_query_params = endpoint.get("merge_query_params", None)
_auth = endpoint.get("auth", None) _auth = endpoint.get("auth", None)
_dependencies = None _dependencies = None
if _auth is not None and str(_auth).lower() == "true": if _auth is not None and str(_auth).lower() == "true":
@ -700,7 +724,12 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
app.add_api_route( # type: ignore app.add_api_route( # type: ignore
path=_path, path=_path,
endpoint=create_pass_through_route( # type: ignore endpoint=create_pass_through_route( # type: ignore
_path, _target, _custom_headers, _forward_headers, _dependencies _path,
_target,
_custom_headers,
_forward_headers,
_merge_query_params,
_dependencies,
), ),
methods=["GET", "POST", "PUT", "DELETE", "PATCH"], methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
dependencies=_dependencies, dependencies=_dependencies,

View file

@ -383,3 +383,72 @@ async def test_pass_through_endpoint_anthropic(client):
# Assert the response # Assert the response
assert response.status_code == 200 assert response.status_code == 200
@pytest.mark.asyncio
async def test_pass_through_endpoint_bing(client, monkeypatch):
import litellm
captured_requests = []
async def mock_bing_request(*args, **kwargs):
captured_requests.append((args, kwargs))
mock_response = httpx.Response(
200,
json={
"_type": "SearchResponse",
"queryContext": {"originalQuery": "bob barker"},
"webPages": {
"webSearchUrl": "https://www.bing.com/search?q=bob+barker",
"totalEstimatedMatches": 12000000,
"value": [],
},
},
)
mock_response.request = Mock(spec=httpx.Request)
return mock_response
monkeypatch.setattr("httpx.AsyncClient.request", mock_bing_request)
# Define a pass-through endpoint
pass_through_endpoints = [
{
"path": "/bing/search",
"target": "https://api.bing.microsoft.com/v7.0/search?setLang=en-US&mkt=en-US",
"headers": {"Ocp-Apim-Subscription-Key": "XX"},
"forward_headers": True,
# Additional settings
"merge_query_params": True,
"auth": True,
},
{
"path": "/bing/search-no-merge-params",
"target": "https://api.bing.microsoft.com/v7.0/search?setLang=en-US&mkt=en-US",
"headers": {"Ocp-Apim-Subscription-Key": "XX"},
"forward_headers": True,
},
]
# Initialize the pass-through endpoint
await initialize_pass_through_endpoints(pass_through_endpoints)
general_settings: Optional[dict] = (
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
)
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
# Make 2 requests thru the pass-through endpoint
client.get("/bing/search?q=bob+barker")
client.get("/bing/search-no-merge-params?q=bob+barker")
first_transformed_url = captured_requests[0][1]["url"]
second_transformed_url = captured_requests[1][1]["url"]
# Assert the response
assert (
first_transformed_url
== "https://api.bing.microsoft.com/v7.0/search?q=bob+barker&setLang=en-US&mkt=en-US"
and second_transformed_url
== "https://api.bing.microsoft.com/v7.0/search?setLang=en-US&mkt=en-US"
)