fix(azure.py): add response header coverage for azure models

Author: Krrish Dholakia
Date:   2024-08-24 15:12:51 -07:00
parent 87549a2391
commit 756a828c15
4 changed files with 62 additions and 42 deletions


@@ -75,9 +75,11 @@ class AzureOpenAIError(Exception):
         message,
         request: Optional[httpx.Request] = None,
         response: Optional[httpx.Response] = None,
+        headers: Optional[httpx.Headers] = None,
     ):
         self.status_code = status_code
         self.message = message
+        self.headers = headers
         if request:
             self.request = request
         else:
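
With the two added lines above, the mapped exception now carries the HTTP response headers alongside the status code. A minimal sketch of the new constructor in use — the class body below is a trimmed stand-in for the real one (request/response handling omitted), and the header values are fabricated for illustration:

    import httpx
    from typing import Optional

    class AzureOpenAIError(Exception):
        # Trimmed stand-in for the class in the diff above.
        def __init__(
            self,
            status_code,
            message,
            headers: Optional[httpx.Headers] = None,
        ):
            self.status_code = status_code
            self.message = message
            self.headers = headers  # new: surfaced so callers can read rate-limit info
            super().__init__(self.message)

    try:
        raise AzureOpenAIError(
            status_code=429,
            message="Rate limit exceeded",
            headers=httpx.Headers({"retry-after": "7"}),
        )
    except AzureOpenAIError as e:
        print(e.status_code, e.headers.get("retry-after"))  # -> 429 7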
@@ -593,7 +595,6 @@ class AzureChatCompletion(BaseLLM):
         client=None,
     ):
         super().completion()
-        exception_mapping_worked = False
         try:
             if model is None or messages is None:
                 raise AzureOpenAIError(
@@ -755,13 +756,13 @@ class AzureChatCompletion(BaseLLM):
                     convert_tool_call_to_json_mode=json_mode,
                 )
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def acompletion(
         self,
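
The getattr calls with defaults replace the old hasattr branching: every exception is normalized to an AzureOpenAIError that falls back to a 500 status and forwards headers only when the underlying SDK error actually exposes them. A standalone sketch of the same pattern, with a fabricated error type standing in for the openai SDK's exceptions:

    def extract_error_fields(e: Exception):
        # Same pattern as the diff: optional attributes with safe fallbacks.
        status_code = getattr(e, "status_code", 500)
        error_headers = getattr(e, "headers", None)
        return status_code, error_headers

    class FakeAPIError(Exception):
        # Hypothetical stand-in for an SDK error carrying HTTP metadata.
        status_code = 429
        headers = {"x-ratelimit-remaining-requests": "0"}

    print(extract_error_fields(FakeAPIError()))      # (429, {'x-ratelimit-remaining-requests': '0'})
    print(extract_error_fields(ValueError("boom")))  # (500, None)

Compared with the hasattr version, this also covers errors that have a status_code but no headers, without nesting further conditionals.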
@@ -1005,10 +1006,11 @@ class AzureChatCompletion(BaseLLM):
             )
             return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def aembedding(
         self,
@@ -1027,7 +1029,9 @@ class AzureChatCompletion(BaseLLM):
                 openai_aclient = AsyncAzureOpenAI(**azure_client_params)
             else:
                 openai_aclient = client
-            response = await openai_aclient.embeddings.create(**data, timeout=timeout)
+            response = await openai_aclient.embeddings.with_raw_response.create(
+                **data, timeout=timeout
+            )
             stringified_response = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
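
Switching to with_raw_response is what makes the headers available in the first place: in the openai SDK it returns a wrapper that exposes .headers, and .parse() recovers the usual typed response object. A sketch against a plain AzureOpenAI client — endpoint, key, and deployment name below are placeholders, not values from this commit:

    from openai import AzureOpenAI

    client = AzureOpenAI(
        azure_endpoint="https://example.openai.azure.com",  # placeholder
        api_key="placeholder-key",                          # placeholder
        api_version="2024-02-01",
    )

    raw = client.embeddings.with_raw_response.create(
        model="my-embedding-deployment",  # placeholder deployment name
        input=["hello world"],
    )
    print(raw.headers.get("x-ratelimit-remaining-requests"))  # HTTP response headers
    embedding_response = raw.parse()  # the regular typed embeddings response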
@@ -1067,7 +1071,6 @@ class AzureChatCompletion(BaseLLM):
         aembedding=None,
     ):
         super().embedding()
-        exception_mapping_worked = False
         if self._client_session is None:
             self._client_session = self.create_client_session()
         try:
@@ -1127,7 +1130,7 @@ class AzureChatCompletion(BaseLLM):
             else:
                 azure_client = client
             ## COMPLETION CALL
-            response = azure_client.embeddings.create(**data, timeout=timeout)  # type: ignore
+            response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout)  # type: ignore
             ## LOGGING
             logging_obj.post_call(
                 input=input,
@@ -1138,13 +1141,13 @@ class AzureChatCompletion(BaseLLM):
             return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding")  # type: ignore
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def make_async_azure_httpx_request(
         self,