diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index 11d2fde8f2..5139723ca1 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -99,6 +99,7 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
             "image_generation",
             "moderation",
             "audio_transcription",
+            "pass_through_endpoint",
         ],
     ) -> Optional[
         Union[Exception, str, dict]
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index cf4a823c39..5d301ea269 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -12,11 +12,12 @@ model_list:
 
 general_settings:
-  alerting: ["slack"]
-  alerting_threshold: 10
   master_key: sk-1234
   pass_through_endpoints:
-    - path: "/v1/test-messages" # route you want to add to LiteLLM Proxy Server
-      target: litellm.adapters.anthropic_adapter.anthropic_adapter # URL this route should forward requests to
-      headers: # headers to forward to this URL
-        litellm_user_api_key: "x-my-test-key"
\ No newline at end of file
+    - path: "/v1/rerank"
+      target: "https://api.cohere.com/v1/rerank"
+      auth: true # 👈 Key change to use LiteLLM Auth / Keys
+      headers:
+        Authorization: "bearer os.environ/COHERE_API_KEY"
+        content-type: application/json
+        accept: application/json
\ No newline at end of file
diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index a17fcb2c97..e9c8649d08 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -1,12 +1,16 @@
-from typing import Optional
-import litellm, traceback, sys
-from litellm.caching import DualCache
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.integrations.custom_logger import CustomLogger
-from fastapi import HTTPException
-from litellm._logging import verbose_proxy_logger
-from litellm import ModelResponse
+import sys
+import traceback
 from datetime import datetime
+from typing import Optional
+
+from fastapi import HTTPException
+
+import litellm
+from litellm import ModelResponse
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
 
 
 class _PROXY_MaxParallelRequestsHandler(CustomLogger):
diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
index b13e9834a2..3d17ba0d73 100644
--- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -3,6 +3,7 @@ import asyncio
 import json
 import traceback
 from base64 import b64encode
+from typing import Optional
 
 import httpx
 from fastapi import (
@@ -240,14 +241,28 @@ async def chat_completion_pass_through_endpoint(
     )
 
 
-async def pass_through_request(request: Request, target: str, custom_headers: dict):
+async def pass_through_request(
+    request: Request,
+    target: str,
+    custom_headers: dict,
+    user_api_key_dict: UserAPIKeyAuth,
+):
     try:
+        import time
+        import uuid
+
+        from litellm.litellm_core_utils.litellm_logging import Logging
+        from litellm.proxy.proxy_server import proxy_logging_obj
         url = httpx.URL(target)
         headers = custom_headers
 
         request_body = await request.body()
-        _parsed_body = ast.literal_eval(request_body.decode("utf-8"))
+        body_str = request_body.decode()
+        try:
+            _parsed_body = ast.literal_eval(body_str)
+        except Exception:
+            _parsed_body = json.loads(body_str)
 
         verbose_proxy_logger.debug(
             "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
@@ -255,6 +270,13 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
             )
         )
 
+        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+        _parsed_body = await proxy_logging_obj.pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            data=_parsed_body,
+            call_type="pass_through_endpoint",
+        )
+
         response = await async_client.request(
             method=request.method,
             url=url,
@@ -267,6 +289,47 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
         )
 
         if response.status_code >= 300:
             raise HTTPException(status_code=response.status_code, detail=response.text)
 
         content = await response.aread()
+
+        ## LOG SUCCESS
+        start_time = time.time()
+        end_time = time.time()
+        # create logging object
+        logging_obj = Logging(
+            model="unknown",
+            messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
+            stream=False,
+            call_type="pass_through_endpoint",
+            start_time=start_time,
+            litellm_call_id=str(uuid.uuid4()),
+            function_id="1245",
+        )
+        # done for supporting 'parallel_request_limiter.py' with pass-through endpoints
+        kwargs = {
+            "litellm_params": {
+                "metadata": {
+                    "user_api_key": user_api_key_dict.api_key,
+                    "user_api_key_user_id": user_api_key_dict.user_id,
+                    "user_api_key_team_id": user_api_key_dict.team_id,
+                    "user_api_key_end_user_id": user_api_key_dict.user_id,
+                }
+            },
+            "call_type": "pass_through_endpoint",
+        }
+        logging_obj.update_environment_variables(
+            model="unknown",
+            user="unknown",
+            optional_params={},
+            litellm_params=kwargs["litellm_params"],
+            call_type="pass_through_endpoint",
+        )
+
+        await logging_obj.async_success_handler(
+            result="",
+            start_time=start_time,
+            end_time=end_time,
+            cache_hit=False,
+        )
+
         return Response(
             content=content,
             status_code=response.status_code,
@@ -274,8 +337,8 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
         )
     except Exception as e:
         verbose_proxy_logger.error(
-            "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}".format(
-                str(e)
+            "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occurred - {}\n{}".format(
+                str(e), traceback.format_exc()
             )
         )
         verbose_proxy_logger.debug(traceback.format_exc())
@@ -296,7 +359,9 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
     )
 
 
-def create_pass_through_route(endpoint, target: str, custom_headers=None):
+def create_pass_through_route(
+    endpoint, target: str, custom_headers: Optional[dict] = None
+):
     # check if target is an adapter.py or a url
     import uuid
 
@@ -325,8 +390,17 @@ def create_pass_through_route(endpoint, target: str, custom_headers=None):
     except Exception:
         verbose_proxy_logger.warning("Defaulting to target being a url.")
 
-    async def endpoint_func(request: Request):  # type: ignore
-        return await pass_through_request(request, target, custom_headers)
+    async def endpoint_func(
+        request: Request,
+        fastapi_response: Response,
+        user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+    ):
+        return await pass_through_request(
+            request=request,
+            target=target,
+            custom_headers=custom_headers or {},
+            user_api_key_dict=user_api_key_dict,
+        )
 
     return endpoint_func
 
@@ -349,7 +423,9 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
         if _auth is not None and str(_auth).lower() == "true":
             if premium_user is not True:
                 raise ValueError(
f"Error Setting Authentication on Pass Through Endpoint: {CommonProxyErrors.not_premium_user}" + "Error Setting Authentication on Pass Through Endpoint: {}".format( + CommonProxyErrors.not_premium_user.value + ) ) _dependencies = [Depends(user_api_key_auth)] LiteLLMRoutes.openai_routes.value.append(_path) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index e129ccdcfa..17fc2ac411 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -299,6 +299,7 @@ class ProxyLogging: "image_generation", "moderation", "audio_transcription", + "pass_through_endpoint", ], ) -> dict: """