fix(azure.py): add response header coverage for azure models

Krrish Dholakia 2024-08-24 15:12:51 -07:00
parent 87549a2391
commit 756a828c15
4 changed files with 62 additions and 42 deletions
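This change threads provider response headers (e.g. Retry-After, x-ratelimit-*) through AzureOpenAIError instead of dropping them, so callers can honor the provider's own retry hints. A minimal sketch of the intended caller-side use, assuming a configured Azure deployment; the model name and credentials are placeholders:

    import litellm

    try:
        litellm.completion(
            model="azure/chatgpt-v-2",  # placeholder deployment name
            messages=[{"role": "user", "content": "hi"}],
        )
    except litellm.RateLimitError as e:
        # after this commit the mapped exception can carry the raw httpx headers
        headers = getattr(e, "headers", None) or {}
        print("retry-after:", headers.get("retry-after"))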


@@ -75,9 +75,11 @@ class AzureOpenAIError(Exception):
         message,
         request: Optional[httpx.Request] = None,
         response: Optional[httpx.Response] = None,
+        headers: Optional[httpx.Headers] = None,
     ):
         self.status_code = status_code
         self.message = message
+        self.headers = headers
         if request:
             self.request = request
         else:
@@ -593,7 +595,6 @@ class AzureChatCompletion(BaseLLM):
         client=None,
     ):
         super().completion()
-        exception_mapping_worked = False
         try:
             if model is None or messages is None:
                 raise AzureOpenAIError(
@@ -755,13 +756,13 @@ class AzureChatCompletion(BaseLLM):
                     convert_tool_call_to_json_mode=json_mode,
                 )
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

    async def acompletion(
        self,
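This except-block rewrite recurs throughout the diff: default to a 500 status when the upstream exception carries none, and forward headers only when present. A standalone sketch of the pattern, using a stand-in class rather than litellm's actual AzureOpenAIError:

    from typing import Optional

    import httpx

    class ProviderError(Exception):  # stand-in for AzureOpenAIError
        def __init__(
            self, status_code: int, message: str, headers: Optional[httpx.Headers] = None
        ):
            super().__init__(message)
            self.status_code = status_code
            self.headers = headers

    def normalize(e: Exception) -> ProviderError:
        # getattr with a default keeps this safe for bare exceptions
        # (e.g. ValueError) that have neither status_code nor headers
        status_code = getattr(e, "status_code", 500)
        error_headers = getattr(e, "headers", None)
        return ProviderError(status_code, str(e), headers=error_headers)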
@@ -1005,10 +1006,11 @@ class AzureChatCompletion(BaseLLM):
             )
             return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

    async def aembedding(
        self,
@@ -1027,7 +1029,9 @@ class AzureChatCompletion(BaseLLM):
                 openai_aclient = AsyncAzureOpenAI(**azure_client_params)
             else:
                 openai_aclient = client
-            response = await openai_aclient.embeddings.create(**data, timeout=timeout)
+            response = await openai_aclient.embeddings.with_raw_response.create(
+                **data, timeout=timeout
+            )
             stringified_response = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
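The move to with_raw_response is what makes headers reachable on the embedding path: in openai-python v1.x the raw-response wrapper exposes the HTTP headers alongside the parsed model. A hedged sketch of the SDK surface this relies on, with placeholder credentials and endpoint:

    from openai import AzureOpenAI

    client = AzureOpenAI(
        api_key="...",
        api_version="2024-02-01",
        azure_endpoint="https://example.openai.azure.com",
    )
    raw = client.embeddings.with_raw_response.create(
        model="text-embedding-ada-002", input=["hello"]
    )
    print(raw.headers.get("x-ratelimit-remaining-requests"))
    embedding_response = raw.parse()  # the usual CreateEmbeddingResponse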
@@ -1067,7 +1071,6 @@ class AzureChatCompletion(BaseLLM):
         aembedding=None,
     ):
         super().embedding()
-        exception_mapping_worked = False
         if self._client_session is None:
             self._client_session = self.create_client_session()
         try:
@@ -1127,7 +1130,7 @@ class AzureChatCompletion(BaseLLM):
             else:
                 azure_client = client
             ## COMPLETION CALL
-            response = azure_client.embeddings.create(**data, timeout=timeout)  # type: ignore
+            response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout)  # type: ignore
             ## LOGGING
             logging_obj.post_call(
                 input=input,
@@ -1138,13 +1141,13 @@ class AzureChatCompletion(BaseLLM):
             return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding")  # type: ignore
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

    async def make_async_azure_httpx_request(
        self,


@@ -33,9 +33,11 @@ class AzureOpenAIError(Exception):
         message,
         request: Optional[httpx.Request] = None,
         response: Optional[httpx.Response] = None,
+        headers: Optional[httpx.Headers] = None,
     ):
         self.status_code = status_code
         self.message = message
+        self.headers = headers
         if request:
             self.request = request
         else:
@@ -311,13 +313,13 @@ class AzureTextCompletion(BaseLLM):
                 )
             )
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

    async def acompletion(
        self,
@@ -387,10 +389,11 @@ class AzureTextCompletion(BaseLLM):
             exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise e
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

    def streaming(
        self,
@@ -443,7 +446,9 @@ class AzureTextCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        response = azure_client.completions.create(**data, timeout=timeout)
+        response = azure_client.completions.with_raw_response.create(
+            **data, timeout=timeout
+        )
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
             model=model,
@@ -501,7 +506,9 @@ class AzureTextCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        response = await azure_client.completions.create(**data, timeout=timeout)
+        response = await azure_client.completions.with_raw_response.create(
+            **data, timeout=timeout
+        )
         # return response
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
@@ -511,7 +518,8 @@ class AzureTextCompletion(BaseLLM):
             )
             return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )


@@ -887,16 +887,19 @@ def _pre_call_utils(
     [True, False],
 )
 @pytest.mark.parametrize(
-    "model, call_type, streaming",
+    "provider, model, call_type, streaming",
     [
-        ("text-embedding-ada-002", "embedding", None),
-        ("gpt-3.5-turbo", "chat_completion", False),
-        ("gpt-3.5-turbo", "chat_completion", True),
-        ("gpt-3.5-turbo-instruct", "completion", True),
+        ("openai", "text-embedding-ada-002", "embedding", None),
+        ("openai", "gpt-3.5-turbo", "chat_completion", False),
+        ("openai", "gpt-3.5-turbo", "chat_completion", True),
+        ("openai", "gpt-3.5-turbo-instruct", "completion", True),
+        ("azure", "azure/chatgpt-v-2", "chat_completion", True),
+        ("azure", "azure/text-embedding-ada-002", "embedding", True),
+        ("azure", "azure_text/gpt-3.5-turbo-instruct", "completion", True),
     ],
 )
 @pytest.mark.asyncio
-async def test_exception_with_headers(sync_mode, model, call_type, streaming):
+async def test_exception_with_headers(sync_mode, provider, model, call_type, streaming):
     """
     User feedback: litellm says "No deployments available for selected model, Try again in 60 seconds"
     but Azure says to retry in at most 9s
@@ -908,9 +911,15 @@ async def test_exception_with_headers(sync_mode, model, call_type, streaming):
     import openai

     if sync_mode:
-        openai_client = openai.OpenAI(api_key="")
+        if provider == "openai":
+            openai_client = openai.OpenAI(api_key="")
+        elif provider == "azure":
+            openai_client = openai.AzureOpenAI(api_key="", base_url="")
     else:
-        openai_client = openai.AsyncOpenAI(api_key="")
+        if provider == "openai":
+            openai_client = openai.AsyncOpenAI(api_key="")
+        elif provider == "azure":
+            openai_client = openai.AsyncAzureOpenAI(api_key="", base_url="")

     data = {"model": model}
     data, original_function, mapped_target = _pre_call_utils(
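The test now exercises Azure deployments alongside OpenAI ones. A hedged sketch of the behavior it is guarding, independent of the mocking in _pre_call_utils; the deployment name and endpoint are placeholders, and the exact exception class depends on what the endpoint returns:

    import pytest
    import litellm

    @pytest.mark.asyncio
    async def test_azure_error_carries_headers():
        with pytest.raises(litellm.AuthenticationError) as exc_info:
            await litellm.acompletion(
                model="azure/chatgpt-v-2",  # placeholder deployment
                messages=[{"role": "user", "content": "hi"}],
                api_key="bad-key",  # force a provider error for the demo
                api_base="https://example.openai.azure.com",
            )
        # the point of the fix: headers from the provider response survive the mapping
        assert hasattr(exc_info.value, "headers")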


@@ -8157,7 +8157,7 @@ def exception_type(
                     model=model,
                     request=original_exception.request,
                 )
-        elif custom_llm_provider == "azure":
+        elif custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
             message = get_error_message(error_obj=original_exception)
             if message is None:
                 if hasattr(original_exception, "message"):
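With the provider check widened, errors from azure_text models take the same mapping branch as azure ones, so get_error_message and the rest of the Azure error handling now apply to both. A minimal sketch, with placeholder deployment and credentials:

    import litellm

    try:
        litellm.completion(
            model="azure_text/gpt-3.5-turbo-instruct",  # placeholder deployment
            messages=[{"role": "user", "content": "hello"}],
            api_key="bad-key",  # force a provider error for the demo
            api_base="https://example.openai.azure.com",
        )
    except litellm.AuthenticationError as e:
        # previously only custom_llm_provider == "azure" got this treatment
        print(e.status_code, getattr(e, "headers", None))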