diff --git a/docs/my-website/docs/proxy/call_hooks.md b/docs/my-website/docs/proxy/call_hooks.md
index 3a8726e879..ce34e5ad6b 100644
--- a/docs/my-website/docs/proxy/call_hooks.md
+++ b/docs/my-website/docs/proxy/call_hooks.md
@@ -17,6 +17,8 @@ This function is called just before a litellm completion call is made, and allow
 ```python
 from litellm.integrations.custom_logger import CustomLogger
 import litellm
+from litellm.proxy.proxy_server import UserAPIKeyAuth, DualCache
+from typing import Optional, Literal
 
 # This file includes the custom callbacks for LiteLLM Proxy
 # Once defined, these can be passed in proxy_config.yaml
@@ -34,7 +36,7 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
         "image_generation",
         "moderation",
         "audio_transcription",
-    ]) -> Optional[dict, str, Exception]:
+    ]):
         data["model"] = "my-new-model"
         return data
 
diff --git a/litellm/proxy/custom_callbacks1.py b/litellm/proxy/custom_callbacks1.py
new file mode 100644
index 0000000000..41962c9aba
--- /dev/null
+++ b/litellm/proxy/custom_callbacks1.py
@@ -0,0 +1,64 @@
+from litellm.integrations.custom_logger import CustomLogger
+import litellm
+from litellm.proxy.proxy_server import UserAPIKeyAuth, DualCache
+from typing import Optional, Literal
+
+
+# This file includes the custom callbacks for LiteLLM Proxy
+# Once defined, these can be passed in proxy_config.yaml
+class MyCustomHandler(
+    CustomLogger
+):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
+    # Class variables or attributes
+    def __init__(self):
+        pass
+
+    #### CALL HOOKS - proxy only ####
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
+    ):
+        return data
+
+    async def async_post_call_failure_hook(
+        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
+    ):
+        pass
+
+    async def async_post_call_success_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+    ):
+        # print("in async_post_call_success_hook")
+        pass
+
+    async def async_moderation_hook(  # call made in parallel to llm api call
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal["completion", "embeddings", "image_generation"],
+    ):
+        pass
+
+    async def async_post_call_streaming_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: str,
+    ):
+        # print("in async_post_call_streaming_hook")
+        pass
+
+
+proxy_handler_instance = MyCustomHandler()
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 85634c5b86..ff7343c3a5 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -20,16 +20,8 @@ model_list:
       api_base: https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings
 
 general_settings:
-  store_model_in_db: true
   master_key: sk-1234
   alerting: ["slack"]
 
 litellm_settings:
-  success_callback: ["langfuse"]
-  failure_callback: ["langfuse"]
-  default_team_settings:
-    - team_id: 7bf09cd5-217a-40d4-8634-fc31d9b88bf4
-      success_callback: ["langfuse"]
-      failure_callback: ["langfuse"]
-      langfuse_public_key: "os.environ/LANGFUSE_DEV_PUBLIC_KEY"
-      langfuse_secret_key: "os.environ/LANGFUSE_DEV_SK_KEY"
+  callbacks: custom_callbacks1.proxy_handler_instance
\ No newline at end of file
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 4045c7d914..062234577c 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -3402,6 +3402,12 @@ async def async_data_generator(
     try:
         start_time = time.time()
         async for chunk in response:
+
+            ### CALL HOOKS ### - modify outgoing data
+            chunk = await proxy_logging_obj.async_post_call_streaming_hook(
+                user_api_key_dict=user_api_key_dict, response=chunk
+            )
+
             chunk = chunk.model_dump_json(exclude_none=True)
             try:
                 yield f"data: {chunk}\n\n"
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 8522b32593..b22134dab3 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -516,6 +516,27 @@ class ProxyLogging:
                 raise e
         return response
 
+    async def async_post_call_streaming_hook(
+        self,
+        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
+        user_api_key_dict: UserAPIKeyAuth,
+    ):
+        """
+        Allow user to modify outgoing streaming data -> per chunk
+
+        Covers:
+        1. /chat/completions
+        """
+        for callback in litellm.callbacks:
+            try:
+                if isinstance(callback, CustomLogger):
+                    await callback.async_post_call_streaming_hook(
+                        user_api_key_dict=user_api_key_dict, response=response
+                    )
+            except Exception as e:
+                raise e
+        return response
+
     async def post_call_streaming_hook(
         self,
         response: str,