From 75b713974fb4e19f0cc4c9981feb6ce71d88f1ff Mon Sep 17 00:00:00 2001 From: Steve Farthing Date: Mon, 27 Jan 2025 08:58:04 -0500 Subject: [PATCH] Bing Search Pass Thru --- litellm/proxy/_types.py | 1 + litellm/proxy/auth/user_api_key_auth.py | 10 +++ .../pass_through_endpoints.py | 31 ++++++++- .../test_pass_through_endpoints.py | 69 +++++++++++++++++++ 4 files changed, 110 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 6b2569eb3c..5ab66e4fcf 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2207,6 +2207,7 @@ class SpecialHeaders(enum.Enum): azure_authorization = "API-Key" anthropic_authorization = "x-api-key" google_ai_studio_authorization = "x-goog-api-key" + bing_search_authorization = "Ocp-Apim-Subscription-Key" class LitellmDataForBackendLLMCall(TypedDict, total=False): diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 84334b1db9..d3caa1194f 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -75,6 +75,11 @@ google_ai_studio_api_key_header = APIKeyHeader( auto_error=False, description="If google ai studio client used.", ) +bing_search_header = APIKeyHeader( + name=SpecialHeaders.bing_search_authorization.value, + auto_error=False, + description="Custom header for Bing Search requests", +) def _get_bearer_token( @@ -284,6 +289,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 azure_api_key_header: str, anthropic_api_key_header: Optional[str], google_ai_studio_api_key_header: Optional[str], + bing_search_header: Optional[str], request_data: dict, ) -> UserAPIKeyAuth: @@ -327,6 +333,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 api_key = anthropic_api_key_header elif isinstance(google_ai_studio_api_key_header, str): api_key = google_ai_studio_api_key_header + elif isinstance(bing_search_header, str): + api_key = bing_search_header elif pass_through_endpoints is not None: for endpoint in pass_through_endpoints: if endpoint.get("path", "") == route: @@ -1152,6 +1160,7 @@ async def user_api_key_auth( google_ai_studio_api_key_header: Optional[str] = fastapi.Security( google_ai_studio_api_key_header ), + bing_search_header: Optional[str] = fastapi.Security(bing_search_header), ) -> UserAPIKeyAuth: """ Parent function to authenticate user api key / jwt token. @@ -1165,6 +1174,7 @@ async def user_api_key_auth( azure_api_key_header=azure_api_key_header, anthropic_api_key_header=anthropic_api_key_header, google_ai_studio_api_key_header=google_ai_studio_api_key_header, + bing_search_header=bing_search_header, request_data=request_data, ) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 970af05f6d..fcbdfc1fc6 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -4,6 +4,7 @@ import json from base64 import b64encode from datetime import datetime from typing import List, Optional +from urllib.parse import urlencode, parse_qs import httpx from fastapi import APIRouter, Depends, HTTPException, Request, Response, status @@ -310,6 +311,7 @@ async def pass_through_request( # noqa: PLR0915 user_api_key_dict: UserAPIKeyAuth, custom_body: Optional[dict] = None, forward_headers: Optional[bool] = False, + merge_query_params: Optional[bool] = False, query_params: Optional[dict] = None, stream: Optional[bool] = None, ): @@ -325,6 +327,25 @@ async def pass_through_request( # noqa: PLR0915 request=request, headers=headers, forward_headers=forward_headers ) + if merge_query_params: + # Get the query params from the request + request_query_params = dict(request.query_params) + + # Get the existing query params from the target URL + existing_query_string = url.query.decode("utf-8") + existing_query_params = parse_qs(existing_query_string) + + # parse_qs returns a dict where each value is a list, so let's flatten it + existing_query_params = { + k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items() + } + + # Merge the query params, giving priority to the existing ones + merged_query_params = {**request_query_params, **existing_query_params} + + # Create a new URL with the merged query params + url = url.copy_with(query=urlencode(merged_query_params).encode("ascii")) + endpoint_type: EndpointType = get_endpoint_type(str(url)) _parsed_body = None @@ -604,6 +625,7 @@ def create_pass_through_route( target: str, custom_headers: Optional[dict] = None, _forward_headers: Optional[bool] = False, + _merge_query_params: Optional[bool] = False, dependencies: Optional[List] = None, ): # check if target is an adapter.py or a url @@ -650,6 +672,7 @@ def create_pass_through_route( custom_headers=custom_headers or {}, user_api_key_dict=user_api_key_dict, forward_headers=_forward_headers, + merge_query_params=_merge_query_params, query_params=query_params, stream=stream, custom_body=custom_body, @@ -679,6 +702,7 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): custom_headers=_custom_headers ) _forward_headers = endpoint.get("forward_headers", None) + _merge_query_params = endpoint.get("merge_query_params", None) _auth = endpoint.get("auth", None) _dependencies = None if _auth is not None and str(_auth).lower() == "true": @@ -700,7 +724,12 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): app.add_api_route( # type: ignore path=_path, endpoint=create_pass_through_route( # type: ignore - _path, _target, _custom_headers, _forward_headers, _dependencies + _path, + _target, + _custom_headers, + _forward_headers, + _merge_query_params, + _dependencies, ), methods=["GET", "POST", "PUT", "DELETE", "PATCH"], dependencies=_dependencies, diff --git a/tests/local_testing/test_pass_through_endpoints.py b/tests/local_testing/test_pass_through_endpoints.py index 7e9dfcfc79..8914b9877e 100644 --- a/tests/local_testing/test_pass_through_endpoints.py +++ b/tests/local_testing/test_pass_through_endpoints.py @@ -383,3 +383,72 @@ async def test_pass_through_endpoint_anthropic(client): # Assert the response assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_pass_through_endpoint_bing(client, monkeypatch): + import litellm + + captured_requests = [] + + async def mock_bing_request(*args, **kwargs): + + captured_requests.append((args, kwargs)) + mock_response = httpx.Response( + 200, + json={ + "_type": "SearchResponse", + "queryContext": {"originalQuery": "bob barker"}, + "webPages": { + "webSearchUrl": "https://www.bing.com/search?q=bob+barker", + "totalEstimatedMatches": 12000000, + "value": [], + }, + }, + ) + mock_response.request = Mock(spec=httpx.Request) + return mock_response + + monkeypatch.setattr("httpx.AsyncClient.request", mock_bing_request) + + # Define a pass-through endpoint + pass_through_endpoints = [ + { + "path": "/bing/search", + "target": "https://api.bing.microsoft.com/v7.0/search?setLang=en-US&mkt=en-US", + "headers": {"Ocp-Apim-Subscription-Key": "XX"}, + "forward_headers": True, + # Additional settings + "merge_query_params": True, + "auth": True, + }, + { + "path": "/bing/search-no-merge-params", + "target": "https://api.bing.microsoft.com/v7.0/search?setLang=en-US&mkt=en-US", + "headers": {"Ocp-Apim-Subscription-Key": "XX"}, + "forward_headers": True, + }, + ] + + # Initialize the pass-through endpoint + await initialize_pass_through_endpoints(pass_through_endpoints) + general_settings: Optional[dict] = ( + getattr(litellm.proxy.proxy_server, "general_settings", {}) or {} + ) + general_settings.update({"pass_through_endpoints": pass_through_endpoints}) + setattr(litellm.proxy.proxy_server, "general_settings", general_settings) + + # Make 2 requests thru the pass-through endpoint + client.get("/bing/search?q=bob+barker") + client.get("/bing/search-no-merge-params?q=bob+barker") + + first_transformed_url = captured_requests[0][1]["url"] + second_transformed_url = captured_requests[1][1]["url"] + + # Assert the response + assert ( + first_transformed_url + == "https://api.bing.microsoft.com/v7.0/search?q=bob+barker&setLang=en-US&mkt=en-US" + and second_transformed_url + == "https://api.bing.microsoft.com/v7.0/search?setLang=en-US&mkt=en-US" + )