diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index b2614faa0..8877c043f 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -398,43 +398,49 @@ class AzureChatCompletion(BaseLLM):
         azure_ad_token: Optional[str] = None,
         client=None,
     ):
-        # init AzureOpenAI Client
-        azure_client_params = {
-            "api_version": api_version,
-            "azure_endpoint": api_base,
-            "azure_deployment": model,
-            "http_client": litellm.client_session,
-            "max_retries": data.pop("max_retries", 2),
-            "timeout": timeout,
-        }
-        if api_key is not None:
-            azure_client_params["api_key"] = api_key
-        elif azure_ad_token is not None:
-            azure_client_params["azure_ad_token"] = azure_ad_token
-        if client is None:
-            azure_client = AsyncAzureOpenAI(**azure_client_params)
-        else:
-            azure_client = client
-        ## LOGGING
-        logging_obj.pre_call(
-            input=data["messages"],
-            api_key=azure_client.api_key,
-            additional_args={
-                "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
-                "api_base": azure_client._base_url._uri_reference,
-                "acompletion": True,
-                "complete_input_dict": data,
-            },
-        )
-        response = await azure_client.chat.completions.create(**data)
-        # return response
-        streamwrapper = CustomStreamWrapper(
-            completion_stream=response,
-            model=model,
-            custom_llm_provider="azure",
-            logging_obj=logging_obj,
-        )
-        return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
+        try:
+            # init AzureOpenAI Client
+            azure_client_params = {
+                "api_version": api_version,
+                "azure_endpoint": api_base,
+                "azure_deployment": model,
+                "http_client": litellm.client_session,
+                "max_retries": data.pop("max_retries", 2),
+                "timeout": timeout,
+            }
+            if api_key is not None:
+                azure_client_params["api_key"] = api_key
+            elif azure_ad_token is not None:
+                azure_client_params["azure_ad_token"] = azure_ad_token
+            if client is None:
+                azure_client = AsyncAzureOpenAI(**azure_client_params)
+            else:
+                azure_client = client
+            ## LOGGING
+            logging_obj.pre_call(
+                input=data["messages"],
+                api_key=azure_client.api_key,
+                additional_args={
+                    "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
+                    "api_base": azure_client._base_url._uri_reference,
+                    "acompletion": True,
+                    "complete_input_dict": data,
+                },
+            )
+            response = await azure_client.chat.completions.create(**data)
+            # return response
+            streamwrapper = CustomStreamWrapper(
+                completion_stream=response,
+                model=model,
+                custom_llm_provider="azure",
+                logging_obj=logging_obj,
+            )
+            return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
+        except Exception as e:
+            if hasattr(e, "status_code"):
+                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
+            else:
+                raise AzureOpenAIError(status_code=500, message=str(e))
 
     async def aembedding(
         self,
diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py
index 6bd2656e2..8f28e86d9 100644
--- a/litellm/tests/test_custom_callback_input.py
+++ b/litellm/tests/test_custom_callback_input.py
@@ -290,6 +290,7 @@ class CompletionCustomHandler(
                 kwargs["original_response"], (str, litellm.CustomStreamWrapper)
             )
             or inspect.isasyncgen(kwargs["original_response"])
+            or inspect.iscoroutine(kwargs["original_response"])
             or kwargs["original_response"] == None
         )
         assert isinstance(kwargs["additional_args"], (dict, type(None)))
@@ -439,7 +440,7 @@ async def test_async_chat_azure_stream():
         )
         async for chunk in response:
             continue
-        ## test failure callback
+        # test failure callback
         try:
             response = await litellm.acompletion(
                 model="azure/chatgpt-v-2",
@@ -459,7 +460,7 @@ async def test_async_chat_azure_stream():
         pytest.fail(f"An exception occurred: {str(e)}")
 
 
-# asyncio.run(test_async_chat_azure_stream())
+asyncio.run(test_async_chat_azure_stream())
 
 
 ## Test Bedrock + sync
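Aside on the `## DO NOT make this into an async for ... loop` comment: below is a minimal standalone sketch (hypothetical function names, not part of the patch) of the pitfall it warns about. A `yield` anywhere in an `async def` body turns the function into an async generator, so exceptions raised during setup would no longer surface when the coroutine is awaited; they would only appear once the caller starts iterating, which would defeat the new `try`/`except` wrapping.

import asyncio

async def returns_stream():
    # Setup error raised here surfaces as soon as the coroutine is awaited.
    raise ValueError("setup failed")

async def yields_stream():
    # Same error, but the mere presence of `yield` makes this an async
    # generator: the body does not run until iteration begins.
    raise ValueError("setup failed")
    yield

async def main():
    try:
        await returns_stream()
    except ValueError as e:
        print("caught at await:", e)

    agen = yields_stream()  # no error yet -- nothing has executed
    try:
        async for _ in agen:  # error only surfaces here, on first iteration
            pass
    except ValueError as e:
        print("caught at iteration:", e)

asyncio.run(main())

Returning the `CustomStreamWrapper` directly keeps the function a plain coroutine, so failures in client construction or `chat.completions.create(**data)` propagate at `await` time and can be rewrapped as `AzureOpenAIError` by the new `except` block.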