From a332cc18619716ef2893b76966f16f02178f3cd1 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 22 Nov 2024 10:56:42 -0800
Subject: [PATCH] add support for using google ai sdk with litellm

---
 litellm/proxy/_types.py                  |  1 +
 litellm/proxy/auth/user_api_key_auth.py  | 10 ++++++++++
 .../llm_passthrough_endpoints.py         |  7 +++++--
 3 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 8b8dbf2e5..0c66868cc 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -2075,6 +2075,7 @@ class SpecialHeaders(enum.Enum):
     openai_authorization = "Authorization"
     azure_authorization = "API-Key"
     anthropic_authorization = "x-api-key"
+    google_ai_studio_authorization = "x-goog-api-key"


 class LitellmDataForBackendLLMCall(TypedDict, total=False):
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index 669661e94..d19215245 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -95,6 +95,11 @@ anthropic_api_key_header = APIKeyHeader(
     auto_error=False,
     description="If anthropic client used.",
 )
+google_ai_studio_api_key_header = APIKeyHeader(
+    name=SpecialHeaders.google_ai_studio_authorization.value,
+    auto_error=False,
+    description="If google ai studio client used.",
+)


 def _get_bearer_token(
@@ -197,6 +202,9 @@ async def user_api_key_auth(  # noqa: PLR0915
     anthropic_api_key_header: Optional[str] = fastapi.Security(
         anthropic_api_key_header
     ),
+    google_ai_studio_api_key_header: Optional[str] = fastapi.Security(
+        google_ai_studio_api_key_header
+    ),
 ) -> UserAPIKeyAuth:
     from litellm.proxy.proxy_server import (
         general_settings,
@@ -233,6 +241,8 @@
         api_key = azure_api_key_header
     elif isinstance(anthropic_api_key_header, str):
         api_key = anthropic_api_key_header
+    elif isinstance(google_ai_studio_api_key_header, str):
+        api_key = google_ai_studio_api_key_header
     elif pass_through_endpoints is not None:
         for endpoint in pass_through_endpoints:
             if endpoint.get("path", "") == route:
diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
index 3f4643afc..7534e6ce9 100644
--- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
@@ -61,10 +61,13 @@ async def gemini_proxy_route(
     fastapi_response: Response,
 ):
     ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
-    api_key = request.query_params.get("key")
+    google_ai_studio_api_key = request.query_params.get("key") or request.headers.get(
+        "x-goog-api-key"
+    )

     user_api_key_dict = await user_api_key_auth(
-        request=request, api_key="Bearer {}".format(api_key)
+        request=request,
+        google_ai_studio_api_key_header=google_ai_studio_api_key,
     )

     base_target_url = "https://generativelanguage.googleapis.com"
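
A minimal client-side sketch of what this patch enables. The values below are assumptions, not taken from the diff: a proxy listening on http://localhost:4000, the Gemini pass-through mounted under /gemini, and "sk-1234" standing in for a LiteLLM virtual key. With REST transport, the Google AI SDK sends its API key in the x-goog-api-key header, which gemini_proxy_route and user_api_key_auth now accept as a LiteLLM key.

# Sketch: pointing the Google AI SDK (google-generativeai) at the LiteLLM proxy.
# Assumed: proxy at http://localhost:4000, Gemini pass-through at /gemini,
# "sk-1234" as a placeholder LiteLLM virtual key.
import google.generativeai as genai

genai.configure(
    api_key="sk-1234",  # LiteLLM virtual key, sent by the SDK as x-goog-api-key
    transport="rest",
    client_options={"api_endpoint": "http://localhost:4000/gemini"},
)

model = genai.GenerativeModel("gemini-1.5-flash")
response = model.generate_content("Write one sentence about LiteLLM.")
print(response.text)

Per the change to gemini_proxy_route, a client that cannot set headers can still pass the same key as the ?key= query parameter, which is checked first.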