fix(azure.py): raise streaming exceptions

Author: Krrish Dholakia
Date: 2023-12-27 15:43:01 +05:30
Parent: f4fe2575cc
Commit: 31148922b3
2 changed files with 46 additions and 39 deletions


@@ -398,43 +398,49 @@ class AzureChatCompletion(BaseLLM):
         azure_ad_token: Optional[str] = None,
         client=None,
     ):
-        # init AzureOpenAI Client
-        azure_client_params = {
-            "api_version": api_version,
-            "azure_endpoint": api_base,
-            "azure_deployment": model,
-            "http_client": litellm.client_session,
-            "max_retries": data.pop("max_retries", 2),
-            "timeout": timeout,
-        }
-        if api_key is not None:
-            azure_client_params["api_key"] = api_key
-        elif azure_ad_token is not None:
-            azure_client_params["azure_ad_token"] = azure_ad_token
-        if client is None:
-            azure_client = AsyncAzureOpenAI(**azure_client_params)
-        else:
-            azure_client = client
-        ## LOGGING
-        logging_obj.pre_call(
-            input=data["messages"],
-            api_key=azure_client.api_key,
-            additional_args={
-                "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
-                "api_base": azure_client._base_url._uri_reference,
-                "acompletion": True,
-                "complete_input_dict": data,
-            },
-        )
-        response = await azure_client.chat.completions.create(**data)
-        # return response
-        streamwrapper = CustomStreamWrapper(
-            completion_stream=response,
-            model=model,
-            custom_llm_provider="azure",
-            logging_obj=logging_obj,
-        )
-        return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
+        try:
+            # init AzureOpenAI Client
+            azure_client_params = {
+                "api_version": api_version,
+                "azure_endpoint": api_base,
+                "azure_deployment": model,
+                "http_client": litellm.client_session,
+                "max_retries": data.pop("max_retries", 2),
+                "timeout": timeout,
+            }
+            if api_key is not None:
+                azure_client_params["api_key"] = api_key
+            elif azure_ad_token is not None:
+                azure_client_params["azure_ad_token"] = azure_ad_token
+            if client is None:
+                azure_client = AsyncAzureOpenAI(**azure_client_params)
+            else:
+                azure_client = client
+            ## LOGGING
+            logging_obj.pre_call(
+                input=data["messages"],
+                api_key=azure_client.api_key,
+                additional_args={
+                    "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
+                    "api_base": azure_client._base_url._uri_reference,
+                    "acompletion": True,
+                    "complete_input_dict": data,
+                },
+            )
+            response = await azure_client.chat.completions.create(**data)
+            # return response
+            streamwrapper = CustomStreamWrapper(
+                completion_stream=response,
+                model=model,
+                custom_llm_provider="azure",
+                logging_obj=logging_obj,
+            )
+            return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
+        except Exception as e:
+            if hasattr(e, "status_code"):
+                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
+            else:
+                raise AzureOpenAIError(status_code=500, message=str(e))

     async def aembedding(
         self,
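Why the "DO NOT make this into an async for ... loop" comment matters: any `async def` whose body contains `yield` becomes an async generator function, and its body does not run until the caller starts iterating, so an exception raised while opening the stream would stay hidden until iteration. A minimal sketch, independent of litellm (function names here are illustrative):

    import asyncio


    async def returns_stream():
        # plain coroutine: the body runs (and raises) as soon as it's awaited
        raise RuntimeError("bad request")


    async def yields_stream():
        # the `yield` below makes this an async generator: calling it only
        # creates the generator object; the body (and this raise) runs on
        # the first __anext__()
        raise RuntimeError("bad request")
        yield


    async def main():
        try:
            await returns_stream()
        except RuntimeError:
            print("caught at call time")  # the behavior this commit wants

        stream = yields_stream()  # no exception here yet
        try:
            async for _ in stream:
                pass
        except RuntimeError:
            print("caught only during iteration")


    asyncio.run(main())

Returning the `CustomStreamWrapper` directly keeps the function a plain coroutine, so the new try/except can convert setup failures into `AzureOpenAIError` immediately.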


@@ -290,6 +290,7 @@ class CompletionCustomHandler(
                     kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                 )
                 or inspect.isasyncgen(kwargs["original_response"])
+                or inspect.iscoroutine(kwargs["original_response"])
                 or kwargs["original_response"] == None
             )
             assert isinstance(kwargs["additional_args"], (dict, type(None)))
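The widened assertion reflects that on the async streaming path the `original_response` handed to a callback may be a plain string, a `CustomStreamWrapper`, an async generator, or (after this change) a not-yet-awaited coroutine. A hedged sketch of how `inspect` tells these shapes apart; the objects below are stand-ins, not litellm's:

    import inspect


    async def coro():
        return "final response"


    async def agen():
        yield "chunk"


    c, g = coro(), agen()
    for obj in ("plain string", c, g, None):
        accepted = (
            isinstance(obj, str)
            or inspect.isasyncgen(obj)
            or inspect.iscoroutine(obj)  # the case this commit adds
            or obj is None
        )
        print(type(obj).__name__, "->", accepted)
    c.close()  # silence the "coroutine was never awaited" warning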
@@ -439,7 +440,7 @@ async def test_async_chat_azure_stream():
         )
         async for chunk in response:
             continue
-        ## test failure callback
+        # test failure callback
         try:
             response = await litellm.acompletion(
                 model="azure/chatgpt-v-2",
@@ -459,7 +460,7 @@ async def test_async_chat_azure_stream():
         pytest.fail(f"An exception occurred: {str(e)}")

-# asyncio.run(test_async_chat_azure_stream())
+asyncio.run(test_async_chat_azure_stream())

## Test Bedrock + sync
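With the exception now raised out of the streaming path, a failure callback has something to fire on. A rough usage sketch, assuming litellm's `failure_callback` hook and its documented custom-callback signature; the model name and key are placeholders, not working credentials:

    import asyncio

    import litellm


    def log_failure(kwargs, completion_response, start_time, end_time):
        # fires because azure.py now raises instead of returning a dead stream
        print("failure callback fired:", kwargs.get("exception"))


    litellm.failure_callback = [log_failure]


    async def main():
        try:
            response = await litellm.acompletion(
                model="azure/chatgpt-v-2",  # placeholder deployment
                messages=[{"role": "user", "content": "hi"}],
                api_key="my-bad-key",  # deliberately invalid
                stream=True,
            )
            async for _ in response:
                pass
        except Exception as e:
            print("caught:", type(e).__name__)


    asyncio.run(main())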