diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 400e23ad0..57bc6f854 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -75,9 +75,11 @@ class AzureOpenAIError(Exception):
         message,
         request: Optional[httpx.Request] = None,
         response: Optional[httpx.Response] = None,
+        headers: Optional[httpx.Headers] = None,
     ):
         self.status_code = status_code
         self.message = message
+        self.headers = headers
         if request:
             self.request = request
         else:
@@ -593,7 +595,6 @@ class AzureChatCompletion(BaseLLM):
         client=None,
     ):
         super().completion()
-        exception_mapping_worked = False
         try:
             if model is None or messages is None:
                 raise AzureOpenAIError(
@@ -755,13 +756,13 @@ class AzureChatCompletion(BaseLLM):
                 convert_tool_call_to_json_mode=json_mode,
             )
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def acompletion(
         self,
@@ -1005,10 +1006,11 @@ class AzureChatCompletion(BaseLLM):
             )
             return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def aembedding(
         self,
@@ -1027,7 +1029,9 @@ class AzureChatCompletion(BaseLLM):
                 openai_aclient = AsyncAzureOpenAI(**azure_client_params)
             else:
                 openai_aclient = client
-            response = await openai_aclient.embeddings.create(**data, timeout=timeout)
+            response = await openai_aclient.embeddings.with_raw_response.create(
+                **data, timeout=timeout
+            )
             stringified_response = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
@@ -1067,7 +1071,6 @@ class AzureChatCompletion(BaseLLM):
         aembedding=None,
     ):
         super().embedding()
-        exception_mapping_worked = False
         if self._client_session is None:
             self._client_session = self.create_client_session()
         try:
@@ -1127,7 +1130,7 @@ class AzureChatCompletion(BaseLLM):
             else:
                 azure_client = client
             ## COMPLETION CALL
-            response = azure_client.embeddings.create(**data, timeout=timeout)  # type: ignore
+            response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout)  # type: ignore
             ## LOGGING
             logging_obj.post_call(
                 input=input,
@@ -1138,13 +1141,13 @@

             return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding")  # type: ignore
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def make_async_azure_httpx_request(
         self,
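Note: the embedding calls in azure.py above switch from client.embeddings.create(...) to client.embeddings.with_raw_response.create(...), which is what makes the provider's HTTP headers available alongside the parsed body. A minimal sketch of the openai-python raw-response pattern being relied on; the client arguments, model name, and header key below are illustrative placeholders, not part of this diff:

    import openai

    client = openai.AzureOpenAI(
        api_key="placeholder",  # illustrative credentials
        api_version="2024-02-01",
        azure_endpoint="https://example.openai.azure.com",
    )
    raw = client.embeddings.with_raw_response.create(
        model="text-embedding-ada-002", input=["hello"]
    )
    print(raw.headers.get("x-ratelimit-remaining-requests"))  # httpx.Headers
    embeddings = raw.parse()  # the parsed embedding response object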
diff --git a/litellm/llms/azure_text.py b/litellm/llms/azure_text.py
index 72d6f134b..d8d7e9d14 100644
--- a/litellm/llms/azure_text.py
+++ b/litellm/llms/azure_text.py
@@ -33,9 +33,11 @@ class AzureOpenAIError(Exception):
         message,
         request: Optional[httpx.Request] = None,
         response: Optional[httpx.Response] = None,
+        headers: Optional[httpx.Headers] = None,
     ):
         self.status_code = status_code
         self.message = message
+        self.headers = headers
         if request:
             self.request = request
         else:
@@ -311,13 +313,13 @@ class AzureTextCompletion(BaseLLM):
                 )
             )
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def acompletion(
         self,
@@ -387,10 +389,11 @@ class AzureTextCompletion(BaseLLM):
             exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise e
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     def streaming(
         self,
@@ -443,7 +446,9 @@ class AzureTextCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        response = azure_client.completions.create(**data, timeout=timeout)
+        response = azure_client.completions.with_raw_response.create(
+            **data, timeout=timeout
+        )
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
             model=model,
@@ -501,7 +506,9 @@ class AzureTextCompletion(BaseLLM):
                 "complete_input_dict": data,
            },
         )
-        response = await azure_client.completions.create(**data, timeout=timeout)
+        response = await azure_client.completions.with_raw_response.create(
+            **data, timeout=timeout
+        )
         # return response
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
@@ -511,7 +518,8 @@ class AzureTextCompletion(BaseLLM):
         )
         return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
     except Exception as e:
-        if hasattr(e, "status_code"):
-            raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-        else:
-            raise AzureOpenAIError(status_code=500, message=str(e))
+        status_code = getattr(e, "status_code", 500)
+        error_headers = getattr(e, "headers", None)
+        raise AzureOpenAIError(
+            status_code=status_code, message=str(e), headers=error_headers
+        )
diff --git a/litellm/tests/test_exceptions.py b/litellm/tests/test_exceptions.py
index ff5c4352f..fbc1dd047 100644
--- a/litellm/tests/test_exceptions.py
+++ b/litellm/tests/test_exceptions.py
@@ -887,16 +887,19 @@ def _pre_call_utils(
     [True, False],
 )
 @pytest.mark.parametrize(
-    "model, call_type, streaming",
+    "provider, model, call_type, streaming",
     [
-        ("text-embedding-ada-002", "embedding", None),
-        ("gpt-3.5-turbo", "chat_completion", False),
-        ("gpt-3.5-turbo", "chat_completion", True),
-        ("gpt-3.5-turbo-instruct", "completion", True),
+        ("openai", "text-embedding-ada-002", "embedding", None),
+        ("openai", "gpt-3.5-turbo", "chat_completion", False),
+        ("openai", "gpt-3.5-turbo", "chat_completion", True),
+        ("openai", "gpt-3.5-turbo-instruct", "completion", True),
+        ("azure", "azure/chatgpt-v-2", "chat_completion", True),
+        ("azure", "azure/text-embedding-ada-002", "embedding", True),
+        ("azure", "azure_text/gpt-3.5-turbo-instruct", "completion", True),
     ],
 )
 @pytest.mark.asyncio
-async def test_exception_with_headers(sync_mode, model, call_type, streaming):
+async def test_exception_with_headers(sync_mode, provider, model, call_type, streaming):
     """
     User feedback: litellm says "No deployments available for selected model, Try again in 60 seconds"
     but Azure says to retry in at most 9s
@@ -908,9 +911,15 @@ async def test_exception_with_headers(sync_mode, model, call_type, streaming):
     import openai

     if sync_mode:
-        openai_client = openai.OpenAI(api_key="")
+        if provider == "openai":
+            openai_client = openai.OpenAI(api_key="")
+        elif provider == "azure":
+            openai_client = openai.AzureOpenAI(api_key="", base_url="")
     else:
-        openai_client = openai.AsyncOpenAI(api_key="")
+        if provider == "openai":
+            openai_client = openai.AsyncOpenAI(api_key="")
+        elif provider == "azure":
+            openai_client = openai.AsyncAzureOpenAI(api_key="", base_url="")

     data = {"model": model}
     data, original_function, mapped_target = _pre_call_utils(
diff --git a/litellm/utils.py b/litellm/utils.py
index d1582cb94..8ac094f59 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -8157,7 +8157,7 @@ def exception_type(
                         model=model,
                         request=original_exception.request,
                     )
-        elif custom_llm_provider == "azure":
+        elif custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
             message = get_error_message(error_obj=original_exception)
             if message is None:
                 if hasattr(original_exception, "message"):
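With the pieces above in place, the AzureOpenAIError raised by these code paths carries the provider's response headers, and exception_type maps azure_text errors the same way as azure ones. A caller-side sketch of how that can be consumed; it assumes only the headers attribute added in this diff, and the model name and messages are placeholders:

    import litellm

    try:
        litellm.completion(
            model="azure/chatgpt-v-2",
            messages=[{"role": "user", "content": "hi"}],
        )
    except litellm.RateLimitError as e:
        headers = getattr(e, "headers", None) or {}
        # Azure may ask for a single-digit wait here; honoring it beats a
        # hard-coded "try again in 60 seconds".
        retry_after = headers.get("retry-after")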