mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Bing Search Pass Thru
This commit is contained in:
parent
1e011b66d3
commit
fe0f9213af
4 changed files with 110 additions and 1 deletions
|
@ -2175,6 +2175,7 @@ class SpecialHeaders(enum.Enum):
|
||||||
azure_authorization = "API-Key"
|
azure_authorization = "API-Key"
|
||||||
anthropic_authorization = "x-api-key"
|
anthropic_authorization = "x-api-key"
|
||||||
google_ai_studio_authorization = "x-goog-api-key"
|
google_ai_studio_authorization = "x-goog-api-key"
|
||||||
|
bing_search_authorization = "Ocp-Apim-Subscription-Key"
|
||||||
|
|
||||||
|
|
||||||
class LitellmDataForBackendLLMCall(TypedDict, total=False):
|
class LitellmDataForBackendLLMCall(TypedDict, total=False):
|
||||||
|
|
|
@ -78,6 +78,11 @@ google_ai_studio_api_key_header = APIKeyHeader(
|
||||||
auto_error=False,
|
auto_error=False,
|
||||||
description="If google ai studio client used.",
|
description="If google ai studio client used.",
|
||||||
)
|
)
|
||||||
|
bing_search_header = APIKeyHeader(
|
||||||
|
name=SpecialHeaders.bing_search_authorization.value,
|
||||||
|
auto_error=False,
|
||||||
|
description="Custom header for Bing Search requests",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_bearer_token(
|
def _get_bearer_token(
|
||||||
|
@ -451,6 +456,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
azure_api_key_header: str,
|
azure_api_key_header: str,
|
||||||
anthropic_api_key_header: Optional[str],
|
anthropic_api_key_header: Optional[str],
|
||||||
google_ai_studio_api_key_header: Optional[str],
|
google_ai_studio_api_key_header: Optional[str],
|
||||||
|
bing_search_header: Optional[str],
|
||||||
request_data: dict,
|
request_data: dict,
|
||||||
) -> UserAPIKeyAuth:
|
) -> UserAPIKeyAuth:
|
||||||
|
|
||||||
|
@ -494,6 +500,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
api_key = anthropic_api_key_header
|
api_key = anthropic_api_key_header
|
||||||
elif isinstance(google_ai_studio_api_key_header, str):
|
elif isinstance(google_ai_studio_api_key_header, str):
|
||||||
api_key = google_ai_studio_api_key_header
|
api_key = google_ai_studio_api_key_header
|
||||||
|
elif isinstance(bing_search_header, str):
|
||||||
|
api_key = bing_search_header
|
||||||
elif pass_through_endpoints is not None:
|
elif pass_through_endpoints is not None:
|
||||||
for endpoint in pass_through_endpoints:
|
for endpoint in pass_through_endpoints:
|
||||||
if endpoint.get("path", "") == route:
|
if endpoint.get("path", "") == route:
|
||||||
|
@ -1317,6 +1325,7 @@ async def user_api_key_auth(
|
||||||
google_ai_studio_api_key_header: Optional[str] = fastapi.Security(
|
google_ai_studio_api_key_header: Optional[str] = fastapi.Security(
|
||||||
google_ai_studio_api_key_header
|
google_ai_studio_api_key_header
|
||||||
),
|
),
|
||||||
|
bing_search_header: Optional[str] = fastapi.Security(bing_search_header),
|
||||||
) -> UserAPIKeyAuth:
|
) -> UserAPIKeyAuth:
|
||||||
"""
|
"""
|
||||||
Parent function to authenticate user api key / jwt token.
|
Parent function to authenticate user api key / jwt token.
|
||||||
|
@ -1330,6 +1339,7 @@ async def user_api_key_auth(
|
||||||
azure_api_key_header=azure_api_key_header,
|
azure_api_key_header=azure_api_key_header,
|
||||||
anthropic_api_key_header=anthropic_api_key_header,
|
anthropic_api_key_header=anthropic_api_key_header,
|
||||||
google_ai_studio_api_key_header=google_ai_studio_api_key_header,
|
google_ai_studio_api_key_header=google_ai_studio_api_key_header,
|
||||||
|
bing_search_header=bing_search_header,
|
||||||
request_data=request_data,
|
request_data=request_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import json
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from urllib.parse import urlencode, parse_qs
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||||
|
@ -310,6 +311,7 @@ async def pass_through_request( # noqa: PLR0915
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
custom_body: Optional[dict] = None,
|
custom_body: Optional[dict] = None,
|
||||||
forward_headers: Optional[bool] = False,
|
forward_headers: Optional[bool] = False,
|
||||||
|
merge_query_params: Optional[bool] = False,
|
||||||
query_params: Optional[dict] = None,
|
query_params: Optional[dict] = None,
|
||||||
stream: Optional[bool] = None,
|
stream: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
|
@ -325,6 +327,25 @@ async def pass_through_request( # noqa: PLR0915
|
||||||
request=request, headers=headers, forward_headers=forward_headers
|
request=request, headers=headers, forward_headers=forward_headers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if merge_query_params:
|
||||||
|
# Get the query params from the request
|
||||||
|
request_query_params = dict(request.query_params)
|
||||||
|
|
||||||
|
# Get the existing query params from the target URL
|
||||||
|
existing_query_string = url.query.decode("utf-8")
|
||||||
|
existing_query_params = parse_qs(existing_query_string)
|
||||||
|
|
||||||
|
# parse_qs returns a dict where each value is a list, so let's flatten it
|
||||||
|
existing_query_params = {
|
||||||
|
k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Merge the query params, giving priority to the existing ones
|
||||||
|
merged_query_params = {**request_query_params, **existing_query_params}
|
||||||
|
|
||||||
|
# Create a new URL with the merged query params
|
||||||
|
url = url.copy_with(query=urlencode(merged_query_params).encode("ascii"))
|
||||||
|
|
||||||
endpoint_type: EndpointType = get_endpoint_type(str(url))
|
endpoint_type: EndpointType = get_endpoint_type(str(url))
|
||||||
|
|
||||||
_parsed_body = None
|
_parsed_body = None
|
||||||
|
@ -604,6 +625,7 @@ def create_pass_through_route(
|
||||||
target: str,
|
target: str,
|
||||||
custom_headers: Optional[dict] = None,
|
custom_headers: Optional[dict] = None,
|
||||||
_forward_headers: Optional[bool] = False,
|
_forward_headers: Optional[bool] = False,
|
||||||
|
_merge_query_params: Optional[bool] = False,
|
||||||
dependencies: Optional[List] = None,
|
dependencies: Optional[List] = None,
|
||||||
):
|
):
|
||||||
# check if target is an adapter.py or a url
|
# check if target is an adapter.py or a url
|
||||||
|
@ -650,6 +672,7 @@ def create_pass_through_route(
|
||||||
custom_headers=custom_headers or {},
|
custom_headers=custom_headers or {},
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
forward_headers=_forward_headers,
|
forward_headers=_forward_headers,
|
||||||
|
merge_query_params=_merge_query_params,
|
||||||
query_params=query_params,
|
query_params=query_params,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
custom_body=custom_body,
|
custom_body=custom_body,
|
||||||
|
@ -679,6 +702,7 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
|
||||||
custom_headers=_custom_headers
|
custom_headers=_custom_headers
|
||||||
)
|
)
|
||||||
_forward_headers = endpoint.get("forward_headers", None)
|
_forward_headers = endpoint.get("forward_headers", None)
|
||||||
|
_merge_query_params = endpoint.get("merge_query_params", None)
|
||||||
_auth = endpoint.get("auth", None)
|
_auth = endpoint.get("auth", None)
|
||||||
_dependencies = None
|
_dependencies = None
|
||||||
if _auth is not None and str(_auth).lower() == "true":
|
if _auth is not None and str(_auth).lower() == "true":
|
||||||
|
@ -700,7 +724,12 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
|
||||||
app.add_api_route( # type: ignore
|
app.add_api_route( # type: ignore
|
||||||
path=_path,
|
path=_path,
|
||||||
endpoint=create_pass_through_route( # type: ignore
|
endpoint=create_pass_through_route( # type: ignore
|
||||||
_path, _target, _custom_headers, _forward_headers, _dependencies
|
_path,
|
||||||
|
_target,
|
||||||
|
_custom_headers,
|
||||||
|
_forward_headers,
|
||||||
|
_merge_query_params,
|
||||||
|
_dependencies,
|
||||||
),
|
),
|
||||||
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||||
dependencies=_dependencies,
|
dependencies=_dependencies,
|
||||||
|
|
|
@ -383,3 +383,72 @@ async def test_pass_through_endpoint_anthropic(client):
|
||||||
|
|
||||||
# Assert the response
|
# Assert the response
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pass_through_endpoint_bing(client, monkeypatch):
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
captured_requests = []
|
||||||
|
|
||||||
|
async def mock_bing_request(*args, **kwargs):
|
||||||
|
|
||||||
|
captured_requests.append((args, kwargs))
|
||||||
|
mock_response = httpx.Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"_type": "SearchResponse",
|
||||||
|
"queryContext": {"originalQuery": "bob barker"},
|
||||||
|
"webPages": {
|
||||||
|
"webSearchUrl": "https://www.bing.com/search?q=bob+barker",
|
||||||
|
"totalEstimatedMatches": 12000000,
|
||||||
|
"value": [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
mock_response.request = Mock(spec=httpx.Request)
|
||||||
|
return mock_response
|
||||||
|
|
||||||
|
monkeypatch.setattr("httpx.AsyncClient.request", mock_bing_request)
|
||||||
|
|
||||||
|
# Define a pass-through endpoint
|
||||||
|
pass_through_endpoints = [
|
||||||
|
{
|
||||||
|
"path": "/bing/search",
|
||||||
|
"target": "https://api.bing.microsoft.com/v7.0/search?setLang=en-US&mkt=en-US",
|
||||||
|
"headers": {"Ocp-Apim-Subscription-Key": "XX"},
|
||||||
|
"forward_headers": True,
|
||||||
|
# Additional settings
|
||||||
|
"merge_query_params": True,
|
||||||
|
"auth": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "/bing/search-no-merge-params",
|
||||||
|
"target": "https://api.bing.microsoft.com/v7.0/search?setLang=en-US&mkt=en-US",
|
||||||
|
"headers": {"Ocp-Apim-Subscription-Key": "XX"},
|
||||||
|
"forward_headers": True,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Initialize the pass-through endpoint
|
||||||
|
await initialize_pass_through_endpoints(pass_through_endpoints)
|
||||||
|
general_settings: Optional[dict] = (
|
||||||
|
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
|
||||||
|
)
|
||||||
|
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
|
||||||
|
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
|
||||||
|
|
||||||
|
# Make 2 requests thru the pass-through endpoint
|
||||||
|
client.get("/bing/search?q=bob+barker")
|
||||||
|
client.get("/bing/search-no-merge-params?q=bob+barker")
|
||||||
|
|
||||||
|
first_transformed_url = captured_requests[0][1]["url"]
|
||||||
|
second_transformed_url = captured_requests[1][1]["url"]
|
||||||
|
|
||||||
|
# Assert the response
|
||||||
|
assert (
|
||||||
|
first_transformed_url
|
||||||
|
== "https://api.bing.microsoft.com/v7.0/search?q=bob+barker&setLang=en-US&mkt=en-US"
|
||||||
|
and second_transformed_url
|
||||||
|
== "https://api.bing.microsoft.com/v7.0/search?setLang=en-US&mkt=en-US"
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue