fix(azure.py): raise streaming exceptions

This commit is contained in:
Krrish Dholakia 2023-12-27 15:43:01 +05:30
parent 1100993834
commit db6ef70a68
2 changed files with 46 additions and 39 deletions

View file

@@ -398,43 +398,49 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token: Optional[str] = None,
client=None,
):
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": data.pop("max_retries", 2),
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
azure_client = client
## LOGGING
logging_obj.pre_call(
input=data["messages"],
api_key=azure_client.api_key,
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
response = await azure_client.chat.completions.create(**data)
# return response
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
)
return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
try:
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": data.pop("max_retries", 2),
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
azure_client = client
## LOGGING
logging_obj.pre_call(
input=data["messages"],
api_key=azure_client.api_key,
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
response = await azure_client.chat.completions.create(**data)
# return response
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
)
return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
except Exception as e:
if hasattr(e, "status_code"):
raise AzureOpenAIError(status_code=e.status_code, message=str(e))
else:
raise AzureOpenAIError(status_code=500, message=str(e))
async def aembedding(
self,