mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

fix(azure.py): add response header coverage for azure models

parent 87549a2391
commit 756a828c15

4 changed files with 62 additions and 42 deletions
@@ -75,9 +75,11 @@ class AzureOpenAIError(Exception):
         message,
         request: Optional[httpx.Request] = None,
         response: Optional[httpx.Response] = None,
+        headers: Optional[httpx.Headers] = None,
     ):
         self.status_code = status_code
         self.message = message
+        self.headers = headers
         if request:
             self.request = request
         else:
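The new headers field is what lets callers surface Azure's own retry hints (for example Retry-After on a 429) instead of a fixed cooldown. A minimal sketch of the intended usage, with a stand-in class that mirrors the constructor above (the real class lives in litellm's Azure handlers):

# Minimal sketch: a stand-in mirroring the constructor in the hunk above,
# showing how a caller can read Azure's retry hint off the raised error.
from typing import Optional

import httpx


class AzureOpenAIError(Exception):
    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        headers: Optional[httpx.Headers] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.headers = headers
        super().__init__(self.message)


try:
    raise AzureOpenAIError(
        status_code=429,
        message="Rate limit exceeded",
        headers=httpx.Headers({"retry-after": "9"}),
    )
except AzureOpenAIError as e:
    # A router can now honor Azure's own hint instead of a generic backoff.
    retry_after = e.headers.get("retry-after") if e.headers else None
    print(e.status_code, retry_after)  # 429 9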
@@ -593,7 +595,6 @@ class AzureChatCompletion(BaseLLM):
        client=None,
    ):
        super().completion()
-        exception_mapping_worked = False
        try:
            if model is None or messages is None:
                raise AzureOpenAIError(
@@ -755,13 +756,13 @@ class AzureChatCompletion(BaseLLM):
                convert_tool_call_to_json_mode=json_mode,
            )
        except AzureOpenAIError as e:
-            exception_mapping_worked = True
            raise e
        except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

    async def acompletion(
        self,
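This rewritten except block is the core of the fix: getattr with a default replaces the hasattr branching, so rich SDK errors contribute their status code and headers while bare exceptions fall back to 500 and None through one code path. The pattern in isolation (UpstreamError is a hypothetical stand-in for an SDK exception):

# Hypothetical stand-in for an SDK error that carries HTTP metadata.
class UpstreamError(Exception):
    def __init__(self, status_code, headers):
        self.status_code = status_code
        self.headers = headers


def normalize(e: Exception):
    # Same defaults as the handler above: 500 / no headers when absent.
    return getattr(e, "status_code", 500), getattr(e, "headers", None)


print(normalize(UpstreamError(429, {"retry-after": "9"})))  # (429, {'retry-after': '9'})
print(normalize(ValueError("boom")))  # (500, None)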
@@ -1005,10 +1006,11 @@ class AzureChatCompletion(BaseLLM):
            )
            return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
        except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

    async def aembedding(
        self,
@@ -1027,7 +1029,9 @@ class AzureChatCompletion(BaseLLM):
                openai_aclient = AsyncAzureOpenAI(**azure_client_params)
            else:
                openai_aclient = client
-            response = await openai_aclient.embeddings.create(**data, timeout=timeout)
+            response = await openai_aclient.embeddings.with_raw_response.create(
+                **data, timeout=timeout
+            )
            stringified_response = response.model_dump()
            ## LOGGING
            logging_obj.post_call(
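Switching the call to with_raw_response is what makes headers available at all: in the openai v1 SDK, the raw-response wrapper exposes the HTTP headers alongside the body, and parse() recovers the usual typed object. A hedged sketch of the access pattern; the endpoint, key, and deployment name are placeholders, not values from this commit:

# Sketch of the raw-response pattern (openai v1 SDK); credentials and
# deployment name are placeholders.
import openai

client = openai.AzureOpenAI(
    api_key="...",
    api_version="2024-02-01",
    azure_endpoint="https://example.openai.azure.com",
)

raw = client.embeddings.with_raw_response.create(
    model="text-embedding-ada-002",  # Azure deployment name
    input=["hello world"],
)
print(raw.headers.get("x-ratelimit-remaining-requests"))  # rate-limit header
response = raw.parse()  # the usual typed CreateEmbeddingResponse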
@@ -1067,7 +1071,6 @@ class AzureChatCompletion(BaseLLM):
        aembedding=None,
    ):
        super().embedding()
-        exception_mapping_worked = False
        if self._client_session is None:
            self._client_session = self.create_client_session()
        try:
@@ -1127,7 +1130,7 @@ class AzureChatCompletion(BaseLLM):
            else:
                azure_client = client
            ## COMPLETION CALL
-            response = azure_client.embeddings.create(**data, timeout=timeout) # type: ignore
+            response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout) # type: ignore
            ## LOGGING
            logging_obj.post_call(
                input=input,
@@ -1138,13 +1141,13 @@ class AzureChatCompletion(BaseLLM):

            return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding") # type: ignore
        except AzureOpenAIError as e:
-            exception_mapping_worked = True
            raise e
        except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

    async def make_async_azure_httpx_request(
        self,
@@ -33,9 +33,11 @@ class AzureOpenAIError(Exception):
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
+        headers: Optional[httpx.Headers] = None,
    ):
        self.status_code = status_code
        self.message = message
+        self.headers = headers
        if request:
            self.request = request
        else:
@@ -311,13 +313,13 @@ class AzureTextCompletion(BaseLLM):
                )
            )
        except AzureOpenAIError as e:
-            exception_mapping_worked = True
            raise e
        except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

    async def acompletion(
        self,
@@ -387,10 +389,11 @@ class AzureTextCompletion(BaseLLM):
            exception_mapping_worked = True
            raise e
        except Exception as e:
-            if hasattr(e, "status_code"):
-                raise e
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

    def streaming(
        self,
@@ -443,7 +446,9 @@ class AzureTextCompletion(BaseLLM):
                    "complete_input_dict": data,
                },
            )
-            response = azure_client.completions.create(**data, timeout=timeout)
+            response = azure_client.completions.with_raw_response.create(
+                **data, timeout=timeout
+            )
            streamwrapper = CustomStreamWrapper(
                completion_stream=response,
                model=model,
@@ -501,7 +506,9 @@ class AzureTextCompletion(BaseLLM):
                    "complete_input_dict": data,
                },
            )
-            response = await azure_client.completions.create(**data, timeout=timeout)
+            response = await azure_client.completions.with_raw_response.create(
+                **data, timeout=timeout
+            )
            # return response
            streamwrapper = CustomStreamWrapper(
                completion_stream=response,
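The same raw-response switch applies to the two streamed completion calls above. Under the assumption that the openai v1 wrapper still exposes headers before the stream is consumed and that parse() yields the usual stream iterator, the pattern looks like this (placeholder credentials and deployment name):

# Sketch: headers are readable on streaming calls too (openai v1 SDK);
# client setup and deployment name are placeholders.
import openai

client = openai.AzureOpenAI(
    api_key="...",
    api_version="2024-02-01",
    azure_endpoint="https://example.openai.azure.com",
)

raw = client.completions.with_raw_response.create(
    model="gpt-35-turbo-instruct",  # placeholder deployment name
    prompt="Say hi",
    stream=True,
)
print(raw.headers.get("x-ratelimit-remaining-requests"))
for chunk in raw.parse():  # parse() returns the usual stream of chunks
    if chunk.choices:
        print(chunk.choices[0].text, end="")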
@@ -511,7 +518,8 @@ class AzureTextCompletion(BaseLLM):
            )
            return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
        except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )
@@ -887,16 +887,19 @@ def _pre_call_utils(
    [True, False],
)
@pytest.mark.parametrize(
-    "model, call_type, streaming",
+    "provider, model, call_type, streaming",
    [
-        ("text-embedding-ada-002", "embedding", None),
-        ("gpt-3.5-turbo", "chat_completion", False),
-        ("gpt-3.5-turbo", "chat_completion", True),
-        ("gpt-3.5-turbo-instruct", "completion", True),
+        ("openai", "text-embedding-ada-002", "embedding", None),
+        ("openai", "gpt-3.5-turbo", "chat_completion", False),
+        ("openai", "gpt-3.5-turbo", "chat_completion", True),
+        ("openai", "gpt-3.5-turbo-instruct", "completion", True),
+        ("azure", "azure/chatgpt-v-2", "chat_completion", True),
+        ("azure", "azure/text-embedding-ada-002", "embedding", True),
+        ("azure", "azure_text/gpt-3.5-turbo-instruct", "completion", True),
    ],
)
@pytest.mark.asyncio
-async def test_exception_with_headers(sync_mode, model, call_type, streaming):
+async def test_exception_with_headers(sync_mode, provider, model, call_type, streaming):
    """
    User feedback: litellm says "No deployments available for selected model, Try again in 60 seconds"
    but Azure says to retry in at most 9s
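Stacked parametrize decorators cross-multiply, so the two sync modes times the seven tuples above now generate fourteen test cases. A tiny self-contained illustration of that stacking (test body and names are illustrative only):

# Stacked @pytest.mark.parametrize decorators produce the cross-product:
# here 2 x 2 = 4 generated test cases.
import pytest


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
    "provider, model",
    [("openai", "gpt-3.5-turbo"), ("azure", "azure/chatgpt-v-2")],
)
def test_matrix(sync_mode, provider, model):
    assert isinstance(sync_mode, bool)
    assert provider in ("openai", "azure")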
@@ -908,9 +911,15 @@ async def test_exception_with_headers(sync_mode, model, call_type, streaming):
    import openai

    if sync_mode:
-        openai_client = openai.OpenAI(api_key="")
+        if provider == "openai":
+            openai_client = openai.OpenAI(api_key="")
+        elif provider == "azure":
+            openai_client = openai.AzureOpenAI(api_key="", base_url="")
    else:
-        openai_client = openai.AsyncOpenAI(api_key="")
+        if provider == "openai":
+            openai_client = openai.AsyncOpenAI(api_key="")
+        elif provider == "azure":
+            openai_client = openai.AsyncAzureOpenAI(api_key="", base_url="")

    data = {"model": model}
    data, original_function, mapped_target = _pre_call_utils(
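Factored out, the four-way selection above is a mapping from the (sync_mode, provider) pair to a client class. A sketch of that mapping; the blank credentials mirror the test above, and the Azure clients may additionally expect an api_version (argument or OPENAI_API_VERSION environment variable):

# Sketch: the four client constructors selected by (sync_mode, provider).
import openai

CLIENTS = {
    (True, "openai"): lambda: openai.OpenAI(api_key=""),
    (True, "azure"): lambda: openai.AzureOpenAI(api_key="", base_url=""),
    (False, "openai"): lambda: openai.AsyncOpenAI(api_key=""),
    (False, "azure"): lambda: openai.AsyncAzureOpenAI(api_key="", base_url=""),
}

sync_mode, provider = True, "azure"
openai_client = CLIENTS[(sync_mode, provider)]()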
@@ -8157,7 +8157,7 @@ def exception_type(
                    model=model,
                    request=original_exception.request,
                )
-        elif custom_llm_provider == "azure":
+        elif custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
            message = get_error_message(error_obj=original_exception)
            if message is None:
                if hasattr(original_exception, "message"):
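azure_text requests raise the same AzureOpenAIError as the chat path, so the exception mapper has to recognize both provider strings; an equivalent, slightly tighter spelling of the added condition:

# Equivalent membership form of the check added above.
def is_azure_provider(custom_llm_provider: str) -> bool:
    return custom_llm_provider in ("azure", "azure_text")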