diff --git a/litellm/tests/test_rules.py b/litellm/tests/test_rules.py
index 7e7a984a9..0997143bf 100644
--- a/litellm/tests/test_rules.py
+++ b/litellm/tests/test_rules.py
@@ -127,3 +127,14 @@ def test_post_call_rule_streaming():
         print(type(e))
         print(vars(e))
         assert e.message == "This violates LiteLLM Proxy Rules. Response too short"
+
+
+def test_post_call_processing_error_async_response():
+    response = asyncio.run(
+        acompletion(
+            model="command-nightly",  # Just used as an example
+            messages=[{"content": "Hello, how are you?", "role": "user"}],
+            api_base="https://openai-proxy.berriai.repl.co",  # Just used as an example
+            custom_llm_provider="openai",
+        )
+    )
diff --git a/litellm/utils.py b/litellm/utils.py
index ee8c34a62..d31bed6cb 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2465,6 +2465,14 @@ def client(original_function):
             )
             raise e
 
+    def check_coroutine(value) -> bool:
+        if inspect.iscoroutine(value):
+            return True
+        elif inspect.iscoroutinefunction(value):
+            return True
+        else:
+            return False
+
     def post_call_processing(original_response, model):
         try:
             if original_response is None:
@@ -2475,11 +2483,15 @@ def client(original_function):
                     call_type == CallTypes.completion.value
                     or call_type == CallTypes.acompletion.value
                 ):
-                    model_response = original_response["choices"][0]["message"][
-                        "content"
-                    ]
-                    ### POST-CALL RULES ###
-                    rules_obj.post_call_rules(input=model_response, model=model)
+                    is_coroutine = check_coroutine(original_function)
+                    if is_coroutine == True:
+                        pass
+                    else:
+                        model_response = original_response["choices"][0]["message"][
+                            "content"
+                        ]
+                        ### POST-CALL RULES ###
+                        rules_obj.post_call_rules(input=model_response, model=model)
         except Exception as e:
             raise e
 