From 5291f380c9856cac325fb2a0896326647fd0da87 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Thu, 23 May 2024 09:30:53 -0700
Subject: [PATCH 1/2] feat - async_post_call_streaming_hook

---
 litellm/proxy/custom_callbacks1.py | 64 ++++++++++++++++++++++++++++++
 litellm/proxy/proxy_config.yaml    | 10 +----
 litellm/proxy/proxy_server.py      |  6 +++
 litellm/proxy/utils.py             | 21 ++++++++++
 4 files changed, 92 insertions(+), 9 deletions(-)
 create mode 100644 litellm/proxy/custom_callbacks1.py

diff --git a/litellm/proxy/custom_callbacks1.py b/litellm/proxy/custom_callbacks1.py
new file mode 100644
index 000000000..41962c9ab
--- /dev/null
+++ b/litellm/proxy/custom_callbacks1.py
@@ -0,0 +1,64 @@
+from litellm.integrations.custom_logger import CustomLogger
+import litellm
+from litellm.proxy.proxy_server import UserAPIKeyAuth, DualCache
+from typing import Optional, Literal
+
+
+# This file includes the custom callbacks for LiteLLM Proxy
+# Once defined, these can be passed in proxy_config.yaml
+class MyCustomHandler(
+    CustomLogger
+):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
+    # Class variables or attributes
+    def __init__(self):
+        pass
+
+    #### CALL HOOKS - proxy only ####
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
+    ):
+        return data
+
+    async def async_post_call_failure_hook(
+        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
+    ):
+        pass
+
+    async def async_post_call_success_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+    ):
+        # print("in async_post_call_success_hook")
+        pass
+
+    async def async_moderation_hook(  # call made in parallel to llm api call
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal["completion", "embeddings", "image_generation"],
+    ):
+        pass
+
+    async def async_post_call_streaming_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: str,
+    ):
+        # print("in async_post_call_streaming_hook")
+        pass
+
+
+proxy_handler_instance = MyCustomHandler()
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 85634c5b8..ff7343c3a 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -20,16 +20,8 @@ model_list:
       api_base: https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings
 
 general_settings:
-  store_model_in_db: true
   master_key: sk-1234
  alerting: ["slack"]
 
 litellm_settings:
-  success_callback: ["langfuse"]
-  failure_callback: ["langfuse"]
-  default_team_settings:
-    - team_id: 7bf09cd5-217a-40d4-8634-fc31d9b88bf4
-      success_callback: ["langfuse"]
-      failure_callback: ["langfuse"]
-      langfuse_public_key: "os.environ/LANGFUSE_DEV_PUBLIC_KEY"
-      langfuse_secret_key: "os.environ/LANGFUSE_DEV_SK_KEY"
+  callbacks: custom_callbacks1.proxy_handler_instance
\ No newline at end of file
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index b52c9b249..cdc4ddfc7 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -3390,6 +3390,12 @@ async def async_data_generator(
     try:
         start_time = time.time()
         async for chunk in response:
+
+            ### CALL HOOKS ### - modify outgoing data
+            chunk = await proxy_logging_obj.async_post_call_streaming_hook(
+                user_api_key_dict=user_api_key_dict, response=chunk
+            )
+
             chunk = chunk.model_dump_json(exclude_none=True)
             try:
                 yield f"data: {chunk}\n\n"
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 8522b3259..b22134dab 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -516,6 +516,27 @@ class ProxyLogging:
                 raise e
         return response
 
+    async def async_post_call_streaming_hook(
+        self,
+        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
+        user_api_key_dict: UserAPIKeyAuth,
+    ):
+        """
+        Allow user to modify outgoing streaming data -> per chunk
+
+        Covers:
+        1. /chat/completions
+        """
+        for callback in litellm.callbacks:
+            try:
+                if isinstance(callback, CustomLogger):
+                    await callback.async_post_call_streaming_hook(
+                        user_api_key_dict=user_api_key_dict, response=response
+                    )
+            except Exception as e:
+                raise e
+        return response
+
     async def post_call_streaming_hook(
         self,
         response: str,

From 908b481f0f5570d37f67dc5322058b84d5483434 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Thu, 23 May 2024 09:34:38 -0700
Subject: [PATCH 2/2] docs - fix litellm call hooks docs

---
 docs/my-website/docs/proxy/call_hooks.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/docs/my-website/docs/proxy/call_hooks.md b/docs/my-website/docs/proxy/call_hooks.md
index 3a8726e87..ce34e5ad6 100644
--- a/docs/my-website/docs/proxy/call_hooks.md
+++ b/docs/my-website/docs/proxy/call_hooks.md
@@ -17,6 +17,8 @@ This function is called just before a litellm completion call is made, and allow
 ```python
 from litellm.integrations.custom_logger import CustomLogger
 import litellm
+from litellm.proxy.proxy_server import UserAPIKeyAuth, DualCache
+from typing import Optional, Literal
 
 # This file includes the custom callbacks for LiteLLM Proxy
 # Once defined, these can be passed in proxy_config.yaml
@@ -34,7 +36,7 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
             "image_generation",
             "moderation",
             "audio_transcription",
-        ]) -> Optional[dict, str, Exception]:
+        ]):
         data["model"] = "my-new-model"
         return data