diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index d508825922..accb4f80f1 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -4,7 +4,7 @@ import dotenv, os
 
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
-
+from litellm.utils import ModelResponse
 from typing import Literal, Union, Optional
 import traceback
 
@@ -64,8 +64,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
         user_api_key_dict: UserAPIKeyAuth,
         cache: DualCache,
         data: dict,
-        call_type: Literal["completion", "embeddings", "image_generation"],
-    ):
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
+    ) -> Optional[
+        Union[Exception, str, dict]
+    ]:  # raise an exception if invalid, return a str for the user to receive if rejected, or return a modified dictionary to pass into litellm
         pass
 
     async def async_post_call_failure_hook(
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index ef5380d089..6ad3e598a5 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -3761,11 +3761,24 @@ async def chat_completion(
 
         data["litellm_logging_obj"] = logging_obj
 
-        ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(
+        ### CALL HOOKS ### - modify/reject incoming data before calling the model
+        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
             user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
         )
+        if isinstance(data, litellm.ModelResponse):
+            return data
+        elif isinstance(data, litellm.CustomStreamWrapper):
+            selected_data_generator = select_data_generator(
+                response=data,
+                user_api_key_dict=user_api_key_dict,
+                request_data={},
+            )
+
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+            )
 
         tasks = []
         tasks.append(
             proxy_logging_obj.during_call_hook(
@@ -3998,10 +4011,24 @@ async def completion(
             data["model"] = litellm.model_alias_map[data["model"]]
 
         ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(
-            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
+        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
+            user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
         )
 
+        if isinstance(data, litellm.TextCompletionResponse):
+            return data
+        elif isinstance(data, litellm.TextCompletionStreamWrapper):
+            selected_data_generator = select_data_generator(
+                response=data,
+                user_api_key_dict=user_api_key_dict,
+                request_data={},
+            )
+
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+            )
+
         ### ROUTE THE REQUESTs ###
         router_model_names = llm_router.model_names if llm_router is not None else []
         # skip router if user passed their key
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 09e772e10b..fc49ebc7fe 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -19,7 +19,16 @@ from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler,
 )
 from litellm._service_logger import ServiceLogging, ServiceTypes
-from litellm import ModelResponse, EmbeddingResponse, ImageResponse
+from litellm import (
+    ModelResponse,
+    EmbeddingResponse,
+    ImageResponse,
+    TranscriptionResponse,
+    TextCompletionResponse,
+    CustomStreamWrapper,
+    TextCompletionStreamWrapper,
+)
+from litellm.utils import ModelResponseIterator
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
 from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
@@ -32,6 +41,7 @@ from email.mime.text import MIMEText
 from email.mime.multipart import MIMEMultipart
 from datetime import datetime, timedelta
 from litellm.integrations.slack_alerting import SlackAlerting
+from typing_extensions import overload
 
 
 def print_verbose(print_statement):
@@ -176,18 +186,60 @@ class ProxyLogging:
         )
         litellm.utils.set_callbacks(callback_list=callback_list)
 
+    # fmt: off
+
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        call_type: Literal["completion"]
+    ) -> Union[dict, ModelResponse, CustomStreamWrapper]:
+        ...
+
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        call_type: Literal["text_completion"]
+    ) -> Union[dict, TextCompletionResponse, TextCompletionStreamWrapper]:
+        ...
+
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        call_type: Literal["embeddings",
+                           "image_generation",
+                           "moderation",
+                           "audio_transcription",]
+    ) -> dict:
+        ...
+
+    # fmt: on
+
+    # The actual implementation of the function
     async def pre_call_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,
         data: dict,
         call_type: Literal[
             "completion",
+            "text_completion",
             "embeddings",
             "image_generation",
             "moderation",
             "audio_transcription",
         ],
-    ):
+    ) -> Union[
+        dict,
+        ModelResponse,
+        TextCompletionResponse,
+        CustomStreamWrapper,
+        TextCompletionStreamWrapper,
+    ]:
         """
         Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
 
@@ -214,7 +266,58 @@ class ProxyLogging:
                         call_type=call_type,
                     )
                     if response is not None:
-                        data = response
+                        if isinstance(response, Exception):
+                            raise response
+                        elif isinstance(response, dict):
+                            data = response
+                        elif isinstance(response, str):
+                            if call_type == "completion":
+                                _chat_response = ModelResponse()
+                                _chat_response.choices[0].message.content = response
+
+                                if (
+                                    data.get("stream", None) is not None
+                                    and data["stream"] == True
+                                ):
+                                    _iterator = ModelResponseIterator(
+                                        model_response=_chat_response
+                                    )
+                                    return CustomStreamWrapper(
+                                        completion_stream=_iterator,
+                                        model=data.get("model", ""),
+                                        custom_llm_provider="cached_response",
+                                        logging_obj=data.get(
+                                            "litellm_logging_obj", None
+                                        ),
+                                    )
+                                return _chat_response
+                            elif call_type == "text_completion":
+                                if (
+                                    data.get("stream", None) is not None
+                                    and data["stream"] == True
+                                ):
+                                    _chat_response = ModelResponse()
+                                    _chat_response.choices[0].message.content = response
+
+                                    if (
+                                        data.get("stream", None) is not None
+                                        and data["stream"] == True
+                                    ):
+                                        _iterator = ModelResponseIterator(
+                                            model_response=_chat_response
+                                        )
+                                        return TextCompletionStreamWrapper(
+                                            completion_stream=_iterator,
+                                            model=data.get("model", ""),
+                                        )
+                                else:
+                                    _response = TextCompletionResponse()
+                                    _response.choices[0].text = response
+                                    return _response
+                            else:
+                                raise HTTPException(
+                                    status_code=400, detail={"error": response}
+                                )
 
             print_verbose(f"final data being sent to {call_type} call: {data}")
             return data
diff --git a/litellm/utils.py b/litellm/utils.py
index ac246fca6a..1e0485755c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -12187,3 +12187,29 @@ def _add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str:
         return request_info
     except:
         return request_info
+
+
+class ModelResponseIterator:
+    def __init__(self, model_response):
+        self.model_response = model_response
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self.model_response
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self.model_response
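
For reviewers, a minimal sketch (not part of the patch) of a guardrail-style handler that exercises the widened async_pre_call_hook contract above; the class name, the trigger condition, and the rejection message are illustrative assumptions, not taken from this diff.

# Sketch only: ExampleGuardrailHandler, the "forbidden" trigger word, and the
# rejection message are hypothetical and not part of this patch.
from typing import Literal, Optional, Union

from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class ExampleGuardrailHandler(CustomLogger):
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ) -> Optional[Union[Exception, str, dict]]:
        # Returning a str rejects the request: ProxyLogging.pre_call_hook wraps it
        # in a ModelResponse / TextCompletionResponse (or the corresponding stream
        # wrapper when data["stream"] is True) and returns it to the caller.
        if "forbidden" in str(data.get("messages", "")).lower():
            return "This request was blocked by the proxy's pre-call guardrail."

        # Returning a dict passes the (possibly modified) request on to litellm.
        data.setdefault("metadata", {})["guardrail_checked"] = True
        return data

A handler like this would typically be registered on the proxy through litellm.callbacks (for example via a custom callbacks module referenced from the proxy config); exact wiring depends on the deployment.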