From 1856ac585d7f07f44b7c039f6eee3b8aa4ffdff7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 17 Aug 2024 16:57:40 -0700 Subject: [PATCH] feat(pass_through_endpoints.py): add pass-through support for all cohere endpoints --- litellm/litellm_core_utils/litellm_logging.py | 2 +- .../pass_through_endpoints.py | 2 +- .../google_ai_studio_endpoints.py | 42 +++++++++++++++++++ 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 77837a7898..df02594953 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -2321,7 +2321,7 @@ def get_standard_logging_object_payload( model_map_value=_model_cost_information, ) except Exception: - verbose_logger.warning( + verbose_logger.debug( # keep in debug otherwise it will trigger on every call "Model is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload" ) model_cost_information = StandardLoggingModelInformation( diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 966593551d..5b9e04d1f5 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -320,7 +320,7 @@ async def pass_through_request( call_type="pass_through_endpoint", ) - async_client = httpx.AsyncClient() + async_client = httpx.AsyncClient(timeout=600) # create logging object start_time = time.time() diff --git a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py index 9eb0f9bd1f..c798e091fb 100644 --- a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py @@ -94,3 +94,45 @@ async def gemini_proxy_route( ) return received_value + + +@router.api_route("/cohere/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]) +async def cohere_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + base_target_url = "https://api.cohere.com" + encoded_endpoint = httpx.URL(endpoint).path + + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL using httpx + base_url = httpx.URL(base_target_url) + updated_url = base_url.copy_with(path=encoded_endpoint) + + # Add or update query parameters + cohere_api_key = litellm.utils.get_secret(secret_name="COHERE_API_KEY") + + ## check for streaming + is_streaming_request = False + if "stream" in str(updated_url): + is_streaming_request = True + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=str(updated_url), + custom_headers={"Authorization": "Bearer {}".format(cohere_api_key)}, + ) # dynamically construct pass-through endpoint based on incoming path + received_value = await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + stream=is_streaming_request, + ) + + return received_value