fix(azure.py): add response header coverage for azure models

This commit is contained in:
Krrish Dholakia 2024-08-24 15:12:51 -07:00
parent 87549a2391
commit 756a828c15
4 changed files with 62 additions and 42 deletions

View file

@ -75,9 +75,11 @@ class AzureOpenAIError(Exception):
message, message,
request: Optional[httpx.Request] = None, request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None, response: Optional[httpx.Response] = None,
headers: Optional[httpx.Headers] = None,
): ):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.headers = headers
if request: if request:
self.request = request self.request = request
else: else:
@ -593,7 +595,6 @@ class AzureChatCompletion(BaseLLM):
client=None, client=None,
): ):
super().completion() super().completion()
exception_mapping_worked = False
try: try:
if model is None or messages is None: if model is None or messages is None:
raise AzureOpenAIError( raise AzureOpenAIError(
@ -755,13 +756,13 @@ class AzureChatCompletion(BaseLLM):
convert_tool_call_to_json_mode=json_mode, convert_tool_call_to_json_mode=json_mode,
) )
except AzureOpenAIError as e: except AzureOpenAIError as e:
exception_mapping_worked = True
raise e raise e
except Exception as e: except Exception as e:
if hasattr(e, "status_code"): status_code = getattr(e, "status_code", 500)
raise AzureOpenAIError(status_code=e.status_code, message=str(e)) error_headers = getattr(e, "headers", None)
else: raise AzureOpenAIError(
raise AzureOpenAIError(status_code=500, message=str(e)) status_code=status_code, message=str(e), headers=error_headers
)
async def acompletion( async def acompletion(
self, self,
@ -1005,10 +1006,11 @@ class AzureChatCompletion(BaseLLM):
) )
return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
except Exception as e: except Exception as e:
if hasattr(e, "status_code"): status_code = getattr(e, "status_code", 500)
raise AzureOpenAIError(status_code=e.status_code, message=str(e)) error_headers = getattr(e, "headers", None)
else: raise AzureOpenAIError(
raise AzureOpenAIError(status_code=500, message=str(e)) status_code=status_code, message=str(e), headers=error_headers
)
async def aembedding( async def aembedding(
self, self,
@ -1027,7 +1029,9 @@ class AzureChatCompletion(BaseLLM):
openai_aclient = AsyncAzureOpenAI(**azure_client_params) openai_aclient = AsyncAzureOpenAI(**azure_client_params)
else: else:
openai_aclient = client openai_aclient = client
response = await openai_aclient.embeddings.create(**data, timeout=timeout) response = await openai_aclient.embeddings.with_raw_response.create(
**data, timeout=timeout
)
stringified_response = response.model_dump() stringified_response = response.model_dump()
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@ -1067,7 +1071,6 @@ class AzureChatCompletion(BaseLLM):
aembedding=None, aembedding=None,
): ):
super().embedding() super().embedding()
exception_mapping_worked = False
if self._client_session is None: if self._client_session is None:
self._client_session = self.create_client_session() self._client_session = self.create_client_session()
try: try:
@ -1127,7 +1130,7 @@ class AzureChatCompletion(BaseLLM):
else: else:
azure_client = client azure_client = client
## COMPLETION CALL ## COMPLETION CALL
response = azure_client.embeddings.create(**data, timeout=timeout) # type: ignore response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout) # type: ignore
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
@ -1138,13 +1141,13 @@ class AzureChatCompletion(BaseLLM):
return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding") # type: ignore return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding") # type: ignore
except AzureOpenAIError as e: except AzureOpenAIError as e:
exception_mapping_worked = True
raise e raise e
except Exception as e: except Exception as e:
if hasattr(e, "status_code"): status_code = getattr(e, "status_code", 500)
raise AzureOpenAIError(status_code=e.status_code, message=str(e)) error_headers = getattr(e, "headers", None)
else: raise AzureOpenAIError(
raise AzureOpenAIError(status_code=500, message=str(e)) status_code=status_code, message=str(e), headers=error_headers
)
async def make_async_azure_httpx_request( async def make_async_azure_httpx_request(
self, self,

View file

@ -33,9 +33,11 @@ class AzureOpenAIError(Exception):
message, message,
request: Optional[httpx.Request] = None, request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None, response: Optional[httpx.Response] = None,
headers: Optional[httpx.Headers] = None,
): ):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.headers = headers
if request: if request:
self.request = request self.request = request
else: else:
@ -311,13 +313,13 @@ class AzureTextCompletion(BaseLLM):
) )
) )
except AzureOpenAIError as e: except AzureOpenAIError as e:
exception_mapping_worked = True
raise e raise e
except Exception as e: except Exception as e:
if hasattr(e, "status_code"): status_code = getattr(e, "status_code", 500)
raise AzureOpenAIError(status_code=e.status_code, message=str(e)) error_headers = getattr(e, "headers", None)
else: raise AzureOpenAIError(
raise AzureOpenAIError(status_code=500, message=str(e)) status_code=status_code, message=str(e), headers=error_headers
)
async def acompletion( async def acompletion(
self, self,
@ -387,10 +389,11 @@ class AzureTextCompletion(BaseLLM):
exception_mapping_worked = True exception_mapping_worked = True
raise e raise e
except Exception as e: except Exception as e:
if hasattr(e, "status_code"): status_code = getattr(e, "status_code", 500)
raise e error_headers = getattr(e, "headers", None)
else: raise AzureOpenAIError(
raise AzureOpenAIError(status_code=500, message=str(e)) status_code=status_code, message=str(e), headers=error_headers
)
def streaming( def streaming(
self, self,
@ -443,7 +446,9 @@ class AzureTextCompletion(BaseLLM):
"complete_input_dict": data, "complete_input_dict": data,
}, },
) )
response = azure_client.completions.create(**data, timeout=timeout) response = azure_client.completions.with_raw_response.create(
**data, timeout=timeout
)
streamwrapper = CustomStreamWrapper( streamwrapper = CustomStreamWrapper(
completion_stream=response, completion_stream=response,
model=model, model=model,
@ -501,7 +506,9 @@ class AzureTextCompletion(BaseLLM):
"complete_input_dict": data, "complete_input_dict": data,
}, },
) )
response = await azure_client.completions.create(**data, timeout=timeout) response = await azure_client.completions.with_raw_response.create(
**data, timeout=timeout
)
# return response # return response
streamwrapper = CustomStreamWrapper( streamwrapper = CustomStreamWrapper(
completion_stream=response, completion_stream=response,
@ -511,7 +518,8 @@ class AzureTextCompletion(BaseLLM):
) )
return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
except Exception as e: except Exception as e:
if hasattr(e, "status_code"): status_code = getattr(e, "status_code", 500)
raise AzureOpenAIError(status_code=e.status_code, message=str(e)) error_headers = getattr(e, "headers", None)
else: raise AzureOpenAIError(
raise AzureOpenAIError(status_code=500, message=str(e)) status_code=status_code, message=str(e), headers=error_headers
)

View file

@ -887,16 +887,19 @@ def _pre_call_utils(
[True, False], [True, False],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model, call_type, streaming", "provider, model, call_type, streaming",
[ [
("text-embedding-ada-002", "embedding", None), ("openai", "text-embedding-ada-002", "embedding", None),
("gpt-3.5-turbo", "chat_completion", False), ("openai", "gpt-3.5-turbo", "chat_completion", False),
("gpt-3.5-turbo", "chat_completion", True), ("openai", "gpt-3.5-turbo", "chat_completion", True),
("gpt-3.5-turbo-instruct", "completion", True), ("openai", "gpt-3.5-turbo-instruct", "completion", True),
("azure", "azure/chatgpt-v-2", "chat_completion", True),
("azure", "azure/text-embedding-ada-002", "embedding", True),
("azure", "azure_text/gpt-3.5-turbo-instruct", "completion", True),
], ],
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_exception_with_headers(sync_mode, model, call_type, streaming): async def test_exception_with_headers(sync_mode, provider, model, call_type, streaming):
""" """
User feedback: litellm says "No deployments available for selected model, Try again in 60 seconds" User feedback: litellm says "No deployments available for selected model, Try again in 60 seconds"
but Azure says to retry in at most 9s but Azure says to retry in at most 9s
@ -908,9 +911,15 @@ async def test_exception_with_headers(sync_mode, model, call_type, streaming):
import openai import openai
if sync_mode: if sync_mode:
if provider == "openai":
openai_client = openai.OpenAI(api_key="") openai_client = openai.OpenAI(api_key="")
elif provider == "azure":
openai_client = openai.AzureOpenAI(api_key="", base_url="")
else: else:
if provider == "openai":
openai_client = openai.AsyncOpenAI(api_key="") openai_client = openai.AsyncOpenAI(api_key="")
elif provider == "azure":
openai_client = openai.AsyncAzureOpenAI(api_key="", base_url="")
data = {"model": model} data = {"model": model}
data, original_function, mapped_target = _pre_call_utils( data, original_function, mapped_target = _pre_call_utils(

View file

@ -8157,7 +8157,7 @@ def exception_type(
model=model, model=model,
request=original_exception.request, request=original_exception.request,
) )
elif custom_llm_provider == "azure": elif custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
message = get_error_message(error_obj=original_exception) message = get_error_message(error_obj=original_exception)
if message is None: if message is None:
if hasattr(original_exception, "message"): if hasattr(original_exception, "message"):