diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index cb4ee84b5..29669a87d 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -41,7 +41,7 @@ def test_completion_custom_provider_model_name():
             messages=messages,
             logger_fn=logger_fn,
         )
-        # Add any assertions here to check the,response
+        # Add any assertions here to check the response
         print(response)
         print(response["choices"][0]["finish_reason"])
     except litellm.Timeout as e:
diff --git a/litellm/utils.py b/litellm/utils.py
index 3ec882f0f..600d80599 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1434,9 +1434,7 @@ class Logging:
                     model = self.model
                     kwargs = self.model_call_details
 
-                    input = kwargs.get(
-                        "messages", kwargs.get("input", None)
-                    )
+                    input = kwargs.get("messages", kwargs.get("input", None))
 
                     type = (
                         "embed"
@@ -1444,7 +1442,7 @@ class Logging:
                         else "llm"
                     )
 
-                    # this only logs streaming once, complete_streaming_response exists i.e when stream ends
+                    # this only logs streaming once, complete_streaming_response exists i.e when stream ends
                     if self.stream:
                         if "complete_streaming_response" not in kwargs:
                             break
@@ -1458,7 +1456,7 @@ class Logging:
                         model=model,
                         input=input,
                         user_id=kwargs.get("user", None),
-                        #user_props=self.model_call_details.get("user_props", None),
+                        # user_props=self.model_call_details.get("user_props", None),
                         extra=kwargs.get("optional_params", {}),
                         response_obj=result,
                         start_time=start_time,
@@ -2064,8 +2062,6 @@ class Logging:
                         else "llm"
                     )
 
-
-
                     lunaryLogger.log_event(
                         type=_type,
                         event="error",
@@ -2509,6 +2505,43 @@ def client(original_function):
 
     @wraps(original_function)
     def wrapper(*args, **kwargs):
+        # DO NOT MOVE THIS. It always needs to run first
+        # Check if this is an async function. If so only execute the async function
+        if (
+            kwargs.get("acompletion", False) == True
+            or kwargs.get("aembedding", False) == True
+            or kwargs.get("aimg_generation", False) == True
+            or kwargs.get("amoderation", False) == True
+            or kwargs.get("atext_completion", False) == True
+            or kwargs.get("atranscription", False) == True
+        ):
+            # [OPTIONAL] CHECK MAX RETRIES / REQUEST
+            if litellm.num_retries_per_request is not None:
+                # check if previous_models passed in as ['litellm_params']['metadata']['previous_models']
+                previous_models = kwargs.get("metadata", {}).get(
+                    "previous_models", None
+                )
+                if previous_models is not None:
+                    if litellm.num_retries_per_request <= len(previous_models):
+                        raise Exception(f"Max retries per request hit!")
+
+            # MODEL CALL
+            result = original_function(*args, **kwargs)
+            if "stream" in kwargs and kwargs["stream"] == True:
+                if (
+                    "complete_response" in kwargs
+                    and kwargs["complete_response"] == True
+                ):
+                    chunks = []
+                    for idx, chunk in enumerate(result):
+                        chunks.append(chunk)
+                    return litellm.stream_chunk_builder(
+                        chunks, messages=kwargs.get("messages", None)
+                    )
+                else:
+                    return result
+            return result
+
         # Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print
         print_args_passed_to_litellm(original_function, args, kwargs)
         start_time = datetime.datetime.now()
@@ -6178,9 +6211,9 @@ def validate_environment(model: Optional[str] = None) -> dict:
 
 
 def set_callbacks(callback_list, function_id=None):
-    
+
    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger
-    
+
     try:
         for callback in callback_list:
             print_verbose(f"callback: {callback}")
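Note on the new `wrapper` fast path above: when the stream is drained with `complete_response=True`, the wrapper collects every chunk and returns a single response rebuilt via `litellm.stream_chunk_builder(chunks, messages=...)`. A minimal sketch of the same pattern from the caller's side, assuming a configured API key; the model name and prompt are illustrative only:

```python
import litellm

messages = [{"role": "user", "content": "Write a one-line haiku."}]

# Stream the response and collect every chunk, as the wrapper does internally.
chunks = []
for chunk in litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True):
    chunks.append(chunk)

# stream_chunk_builder stitches the chunks back into a single,
# non-streaming-style response object (the same call the diff adds).
rebuilt = litellm.stream_chunk_builder(chunks, messages=messages)
print(rebuilt["choices"][0]["message"]["content"])
```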
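Note on the retry guard: `litellm.num_retries_per_request` caps attempts per request by comparing against the `previous_models` list carried in `kwargs["metadata"]`. A standalone sketch of that check's logic; `check_retry_budget` is a hypothetical helper name for illustration, not part of litellm:

```python
# Stand-in for litellm.num_retries_per_request (an int cap, None disables the check).
num_retries_per_request = 2

def check_retry_budget(kwargs: dict) -> None:
    # Mirrors the diff: read metadata["previous_models"] and raise once the
    # number of prior attempts reaches the cap.
    previous_models = kwargs.get("metadata", {}).get("previous_models", None)
    if previous_models is not None and num_retries_per_request <= len(previous_models):
        raise Exception("Max retries per request hit!")

# Two prior attempts recorded -> the cap of 2 is hit and the guard raises.
try:
    check_retry_budget({"metadata": {"previous_models": ["gpt-3.5-turbo", "azure/gpt-4"]}})
except Exception as e:
    print(e)  # Max retries per request hit!
```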