diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index ab79fe0bc2..a828c2caef 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -221,7 +221,7 @@ class AzureChatCompletion(BaseLLM):
                   timeout: Any,
                   azure_ad_token: Optional[str]=None,
                   client=None,
-    ): 
+    ):
         max_retries = data.pop("max_retries", 2)
         if not isinstance(max_retries, int):
             raise AzureOpenAIError(status_code=422, message="max retries must be an int")
@@ -244,8 +244,7 @@ class AzureChatCompletion(BaseLLM):
             azure_client = client
         response = azure_client.chat.completions.create(**data)
         streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
-        for transformed_chunk in streamwrapper:
-            yield transformed_chunk
+        return streamwrapper
 
     async def async_streaming(self,
                               logging_obj,
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 41570f0b5c..d2c5b6ff42 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -293,8 +293,7 @@ class OpenAIChatCompletion(BaseLLM):
             openai_client = client
         response = openai_client.chat.completions.create(**data)
         streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
-        for transformed_chunk in streamwrapper:
-            yield transformed_chunk
+        return streamwrapper
 
     async def async_streaming(self,
                               logging_obj,
diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py
index 1527423ca8..c31eba0c74 100644
--- a/litellm/tests/test_custom_logger.py
+++ b/litellm/tests/test_custom_logger.py
@@ -9,6 +9,7 @@ from litellm.integrations.custom_logger import CustomLogger
 
 class MyCustomHandler(CustomLogger):
     success: bool = False
+    failure: bool = False
 
     def log_pre_api_call(self, model, messages, kwargs):
         print(f"Pre-API Call")
@@ -25,6 +26,7 @@
 
     def log_failure_event(self, kwargs, response_obj, start_time, end_time):
         print(f"On Failure")
+        self.failure = True
 
 def test_chat_openai():
     try:
@@ -51,10 +53,34 @@
         pass
 
 
-test_chat_openai()
-
+# test_chat_openai()
+def test_completion_azure_stream_moderation_failure():
+    try:
+        customHandler = MyCustomHandler()
+        litellm.callbacks = [customHandler]
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "how do i kill someone",
+            },
+        ]
+        try:
+            response = completion(
+                model="azure/chatgpt-v-2", messages=messages, stream=True
+            )
+            for chunk in response:
+                print(f"chunk: {chunk}")
+                continue
+        except Exception as e:
+            print(e)
+        time.sleep(1)
+        assert customHandler.failure == True
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+test_completion_azure_stream_moderation_failure()
 
 
 
 
 # def custom_callback(
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 537c8c25fd..417bd64d9e 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -273,7 +273,7 @@ def test_completion_azure_function_calling_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_azure_function_calling_stream()
+# test_completion_azure_function_calling_stream()
 
 def test_completion_claude_stream():
     try: