From bc0023a40905d2e5dc38e9dbefd9cc0f800ec434 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 17 Aug 2024 10:46:59 -0700
Subject: [PATCH] feat(google_ai_studio_endpoints.py): support pass-through
 endpoint for all google ai studio requests

New Feature
---
 litellm/cost_calculator.py                    |  9 ++-
 litellm/litellm_core_utils/litellm_logging.py |  3 +-
 .../pass_through_endpoints.py                 | 62 +++++++++++----
 litellm/proxy/proxy_server.py                 |  4 +
 .../google_ai_studio_endpoints.py             | 79 +++++++++++++++++++
 litellm/tests/test_proxy_server.py            | 49 ++++++++++++
 6 files changed, 186 insertions(+), 20 deletions(-)

diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py
index 4b5ac51db9..d6f9adb008 100644
--- a/litellm/cost_calculator.py
+++ b/litellm/cost_calculator.py
@@ -412,7 +412,7 @@ def get_replicate_completion_pricing(completion_response=None, total_time=0.0):
 
 def _select_model_name_for_cost_calc(
     model: Optional[str],
-    completion_response: Union[BaseModel, dict],
+    completion_response: Union[BaseModel, dict, str],
     base_model: Optional[str] = None,
     custom_pricing: Optional[bool] = None,
 ) -> Optional[str]:
@@ -428,7 +428,12 @@ def _select_model_name_for_cost_calc(
     if base_model is not None:
         return base_model
 
-    return_model = model or completion_response.get("model", "")  # type: ignore
+    return_model = model
+    if isinstance(completion_response, str):
+        return return_model
+
+    elif return_model is None:
+        return_model = completion_response.get("model", "")  # type: ignore
     if hasattr(completion_response, "_hidden_params"):
         if (
             completion_response._hidden_params.get("model", None) is not None
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 74c6d0db01..77837a7898 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -274,6 +274,7 @@ class Logging:
                 headers = {}
             data = additional_args.get("complete_input_dict", {})
             api_base = str(additional_args.get("api_base", ""))
+            query_params = additional_args.get("query_params", {})
             if "key=" in api_base:
                 # Find the position of "key=" in the string
                 key_index = api_base.find("key=") + 4
@@ -2362,7 +2363,7 @@ def get_standard_logging_object_payload(
 
         return payload
     except Exception as e:
-        verbose_logger.error(
+        verbose_logger.exception(
             "Error creating standard logging object - {}".format(str(e))
        )
         return None
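Reviewer note: the `cost_calculator.py` change above exists because pass-through responses reach cost tracking as raw strings rather than `dict`/`BaseModel` objects. The following is a minimal standalone sketch of the new short-circuit; `select_model_name` is a hypothetical name, and the real helper also consults `base_model`, `custom_pricing`, and `_hidden_params`:

```python
from typing import Optional, Union

from pydantic import BaseModel


def select_model_name(
    model: Optional[str],
    completion_response: Union[BaseModel, dict, str],
) -> Optional[str]:
    # Raw pass-through bodies carry no parseable "model" field,
    # so fall back to whatever model name the caller supplied.
    if isinstance(completion_response, str):
        return model
    # Only read the response's model name when none was passed in.
    if model is None and isinstance(completion_response, dict):
        return completion_response.get("model", "")
    return model


assert select_model_name(None, "raw pass-through text") is None
assert select_model_name(None, {"model": "gemini-1.5-flash"}) == "gemini-1.5-flash"
assert select_model_name("gpt-4o", {"model": "something-else"}) == "gpt-4o"
```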
"no-message-pass-through-endpoint"}], @@ -334,6 +321,7 @@ async def pass_through_request( litellm_call_id=str(uuid.uuid4()), function_id="1245", ) + # done for supporting 'parallel_request_limiter.py' with pass-through endpoints kwargs = { "litellm_params": { @@ -354,6 +342,44 @@ async def pass_through_request( call_type="pass_through_endpoint", ) + # combine url with query params for logging + + requested_query_params = query_params or request.query_params.__dict__ + requested_query_params_str = "&".join( + f"{k}={v}" for k, v in requested_query_params.items() + ) + + if "?" in str(url): + logging_url = str(url) + "&" + requested_query_params_str + else: + logging_url = str(url) + "?" + requested_query_params_str + + logging_obj.pre_call( + input=[{"role": "user", "content": "no-message-pass-through-endpoint"}], + api_key="", + additional_args={ + "complete_input_dict": _parsed_body, + "api_base": logging_url, + "headers": headers, + }, + ) + + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + json=_parsed_body, + ) + + if response.status_code >= 300: + raise HTTPException(status_code=response.status_code, detail=response.text) + + content = await response.aread() + + ## LOG SUCCESS + end_time = time.time() + await logging_obj.async_success_handler( result="", start_time=start_time, @@ -431,17 +457,19 @@ def create_pass_through_route( except Exception: verbose_proxy_logger.debug("Defaulting to target being a url.") - async def endpoint_func( + async def endpoint_func( # type: ignore request: Request, fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + query_params: Optional[dict] = None, ): - return await pass_through_request( + return await pass_through_request( # type: ignore request=request, target=target, custom_headers=custom_headers or {}, user_api_key_dict=user_api_key_dict, forward_headers=_forward_headers, + query_params=query_params, ) return endpoint_func diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f09ae7d350..60df87e729 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -227,6 +227,9 @@ from litellm.proxy.utils import ( send_email, update_spend, ) +from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import ( + router as gemini_router, +) from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import router as vertex_router from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import set_default_vertex_config from litellm.router import ( @@ -9704,6 +9707,7 @@ def cleanup_router_config_variables(): app.include_router(router) app.include_router(fine_tuning_router) app.include_router(vertex_router) +app.include_router(gemini_router) app.include_router(pass_through_router) app.include_router(health_router) app.include_router(key_management_router) diff --git a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py index 1bb966964d..3b6105b447 100644 --- a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py @@ -3,3 +3,82 @@ What is this? Google AI Studio Pass-Through Endpoints """ + +""" +1. 
diff --git a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py
index 1bb966964d..3b6105b447 100644
--- a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py
+++ b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py
@@ -3,3 +3,82 @@ What is this?
 
 Google AI Studio Pass-Through Endpoints
 """
+
+"""
+1. Create pass-through endpoints that map any LITELLM_BASE_URL/gemini/ request to https://generativelanguage.googleapis.com/
+"""
+
+import ast
+import asyncio
+import traceback
+from datetime import datetime, timedelta, timezone
+from typing import List, Optional
+from urllib.parse import urlencode
+
+import fastapi
+import httpx
+from fastapi import (
+    APIRouter,
+    Depends,
+    File,
+    Form,
+    Header,
+    HTTPException,
+    Request,
+    Response,
+    UploadFile,
+    status,
+)
+from starlette.datastructures import QueryParams
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.batches.main import FileObject
+from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
+from litellm.proxy._types import *
+from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
+    create_pass_through_route,
+)
+
+router = APIRouter()
+default_vertex_config = None
+
+
+@router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def gemini_proxy_route(
+    endpoint: str,
+    request: Request,
+    fastapi_response: Response,
+):
+    ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?key=LITELLM_API_KEY
+    api_key = request.query_params.get("key")
+
+    user_api_key_dict = await user_api_key_auth(
+        request=request, api_key="Bearer {}".format(api_key)
+    )
+
+    base_target_url = "https://generativelanguage.googleapis.com"
+    encoded_endpoint = httpx.URL(endpoint).path
+
+    # Ensure endpoint starts with '/' for proper URL construction
+    if not encoded_endpoint.startswith("/"):
+        encoded_endpoint = "/" + encoded_endpoint
+
+    # Construct the full target URL using httpx
+    base_url = httpx.URL(base_target_url)
+    updated_url = base_url.copy_with(path=encoded_endpoint)
+
+    # Add or update query parameters
+    gemini_api_key = litellm.utils.get_secret(secret_name="GEMINI_API_KEY")
+    # Merge query parameters, replacing the incoming LiteLLM key with the server-side GEMINI_API_KEY
+    merged_params = dict(request.query_params)
+    merged_params.update({"key": gemini_api_key})
+
+    endpoint_func = create_pass_through_route(
+        endpoint=endpoint,
+        target=str(updated_url),
+    )  # dynamically construct pass-through endpoint based on incoming path
+    return await endpoint_func(
+        request, fastapi_response, user_api_key_dict, query_params=merged_params
+    )
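Reviewer note: the key swap is the security-relevant behavior in `gemini_proxy_route`: callers authenticate to the proxy with a LiteLLM key in `?key=...`, and the upstream request is sent with the server-side `GEMINI_API_KEY` instead. A minimal sketch of the URL and param handling, mirroring the handler above (key values are placeholders):

```python
import httpx

incoming = "v1beta/models/gemini-1.5-flash:countTokens?key=sk-1234"

# httpx.URL(...).path keeps only the route and drops the query string.
path = httpx.URL(incoming).path
if not path.startswith("/"):
    path = "/" + path

target = httpx.URL("https://generativelanguage.googleapis.com").copy_with(path=path)

# dict(...) copies the caller's params; update() replaces the proxy key
# with the provider key before forwarding.
merged_params = {"key": "sk-1234"}
merged_params.update({"key": "<GEMINI_API_KEY>"})

print(target)         # https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:countTokens
print(merged_params)  # {'key': '<GEMINI_API_KEY>'}
```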
diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py
index 9a1c091267..28f3aad632 100644
--- a/litellm/tests/test_proxy_server.py
+++ b/litellm/tests/test_proxy_server.py
@@ -1166,3 +1166,52 @@ async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
     assert new_data["success_callback"] == ["langfuse"]
     assert "langfuse_public_key" in new_data
     assert "langfuse_secret_key" in new_data
+
+
+@pytest.mark.asyncio
+async def test_gemini_pass_through_endpoint():
+    from starlette.datastructures import URL
+
+    from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import (
+        Request,
+        Response,
+        gemini_proxy_route,
+    )
+
+    body = b"""
+    {
+        "contents": [{
+            "parts":[{
+                "text": "The quick brown fox jumps over the lazy dog."
+            }]
+        }]
+    }
+    """
+
+    # Construct the scope dictionary
+    scope = {
+        "type": "http",
+        "method": "POST",
+        "path": "/gemini/v1beta/models/gemini-1.5-flash:countTokens",
+        "query_string": b"key=sk-1234",
+        "headers": [
+            (b"content-type", b"application/json"),
+        ],
+    }
+
+    # Create a new Request object
+    async def async_receive():
+        return {"type": "http.request", "body": body, "more_body": False}
+
+    request = Request(
+        scope=scope,
+        receive=async_receive,
+    )
+
+    resp = await gemini_proxy_route(
+        endpoint="v1beta/models/gemini-1.5-flash:countTokens?key=sk-1234",
+        request=request,
+        fastapi_response=Response(),
+    )
+
+    print(resp.body)
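Usage sketch: with the proxy running, the new route accepts Google AI Studio-shaped requests under the `/gemini` prefix, as the test above exercises end to end. The base URL and `sk-1234` key below are assumptions for illustration; substitute your own deployment values:

```python
import httpx

LITELLM_BASE_URL = "http://0.0.0.0:4000"  # assumed local proxy address
LITELLM_API_KEY = "sk-1234"               # assumed LiteLLM key

resp = httpx.post(
    f"{LITELLM_BASE_URL}/gemini/v1beta/models/gemini-1.5-flash:countTokens",
    params={"key": LITELLM_API_KEY},  # proxy auth; swapped for GEMINI_API_KEY upstream
    json={
        "contents": [
            {"parts": [{"text": "The quick brown fox jumps over the lazy dog."}]}
        ]
    },
)
print(resp.status_code, resp.json())
```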