forked from phoenix/litellm-mirror
fix(azure.py): raise streaming exceptions
parent f4fe2575cc
commit 31148922b3
2 changed files with 46 additions and 39 deletions
@@ -398,43 +398,49 @@ class AzureChatCompletion(BaseLLM):
         azure_ad_token: Optional[str] = None,
         client=None,
     ):
-        # init AzureOpenAI Client
-        azure_client_params = {
-            "api_version": api_version,
-            "azure_endpoint": api_base,
-            "azure_deployment": model,
-            "http_client": litellm.client_session,
-            "max_retries": data.pop("max_retries", 2),
-            "timeout": timeout,
-        }
-        if api_key is not None:
-            azure_client_params["api_key"] = api_key
-        elif azure_ad_token is not None:
-            azure_client_params["azure_ad_token"] = azure_ad_token
-        if client is None:
-            azure_client = AsyncAzureOpenAI(**azure_client_params)
-        else:
-            azure_client = client
-        ## LOGGING
-        logging_obj.pre_call(
-            input=data["messages"],
-            api_key=azure_client.api_key,
-            additional_args={
-                "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
-                "api_base": azure_client._base_url._uri_reference,
-                "acompletion": True,
-                "complete_input_dict": data,
-            },
-        )
-        response = await azure_client.chat.completions.create(**data)
-        # return response
-        streamwrapper = CustomStreamWrapper(
-            completion_stream=response,
-            model=model,
-            custom_llm_provider="azure",
-            logging_obj=logging_obj,
-        )
-        return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
+        try:
+            # init AzureOpenAI Client
+            azure_client_params = {
+                "api_version": api_version,
+                "azure_endpoint": api_base,
+                "azure_deployment": model,
+                "http_client": litellm.client_session,
+                "max_retries": data.pop("max_retries", 2),
+                "timeout": timeout,
+            }
+            if api_key is not None:
+                azure_client_params["api_key"] = api_key
+            elif azure_ad_token is not None:
+                azure_client_params["azure_ad_token"] = azure_ad_token
+            if client is None:
+                azure_client = AsyncAzureOpenAI(**azure_client_params)
+            else:
+                azure_client = client
+            ## LOGGING
+            logging_obj.pre_call(
+                input=data["messages"],
+                api_key=azure_client.api_key,
+                additional_args={
+                    "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
+                    "api_base": azure_client._base_url._uri_reference,
+                    "acompletion": True,
+                    "complete_input_dict": data,
+                },
+            )
+            response = await azure_client.chat.completions.create(**data)
+            # return response
+            streamwrapper = CustomStreamWrapper(
+                completion_stream=response,
+                model=model,
+                custom_llm_provider="azure",
+                logging_obj=logging_obj,
+            )
+            return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
+        except Exception as e:
+            if hasattr(e, "status_code"):
+                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
+            else:
+                raise AzureOpenAIError(status_code=500, message=str(e))
 
     async def aembedding(
         self,
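The `## DO NOT` comment in the hunk above is the crux of the fix. A standalone sketch (not part of the commit) of the pitfall it describes: a coroutine runs its whole body, including the try/except, as soon as it is awaited, while an async generator returns immediately and raises nothing until the caller starts iterating, so wrapper-level error handling would be silently skipped.

import asyncio


async def stream_as_coroutine():
    # Mirrors the fixed method: the body (and its try/except) runs when
    # awaited, so setup failures surface immediately at the call site.
    raise RuntimeError("client setup failed")


async def stream_as_async_generator():
    # What the comment warns against: the presence of `yield` makes this an
    # async generator, so calling it returns a generator object and runs
    # nothing -- the exception stays hidden until iteration begins.
    raise RuntimeError("client setup failed")
    yield


async def main():
    try:
        await stream_as_coroutine()
    except RuntimeError as e:
        print(f"coroutine raised eagerly: {e}")

    gen = stream_as_async_generator()  # no exception raised here
    try:
        async for _ in gen:
            pass
    except RuntimeError as e:
        print(f"generator raised only on iteration: {e}")


asyncio.run(main())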
@@ -290,6 +290,7 @@ class CompletionCustomHandler(
                     kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                 )
                 or inspect.isasyncgen(kwargs["original_response"])
+                or inspect.iscoroutine(kwargs["original_response"])
                 or kwargs["original_response"] == None
             )
             assert isinstance(kwargs["additional_args"], (dict, type(None)))
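For reference, a quick standalone check (not from the commit) of what the two inspect predicates in this assertion match: calling a plain `async def` function yields a coroutine, while calling an `async def` that contains `yield` yields an async generator.

import inspect


async def coro_fn():
    return 1


async def agen_fn():
    yield 1


c = coro_fn()
g = agen_fn()
print(inspect.iscoroutine(c))  # True: result of calling an async def
print(inspect.isasyncgen(g))   # True: result of calling an async generator function
print(inspect.iscoroutine(g))  # False: an async generator is not a coroutine
c.close()  # suppress the "coroutine was never awaited" warning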
@@ -439,7 +440,7 @@ async def test_async_chat_azure_stream():
         )
         async for chunk in response:
             continue
-        ## test failure callback
+        # test failure callback
         try:
             response = await litellm.acompletion(
                 model="azure/chatgpt-v-2",
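A hypothetical sketch (not from this test file) of exercising the new error path end to end: with the fix, a failing Azure stream raises from the awaited call instead of handing back a generator that never errors. The invalid `api_key` value and the broad `pytest.raises(Exception)` are assumptions for illustration only.

import litellm
import pytest


@pytest.mark.asyncio
async def test_azure_stream_raises_instead_of_swallowing():
    with pytest.raises(Exception):  # assumption: any mapped litellm error type
        response = await litellm.acompletion(
            model="azure/chatgpt-v-2",  # deployment used in the test above
            messages=[{"role": "user", "content": "hi"}],
            stream=True,
            api_key="bad-key",  # hypothetical invalid key to force the except branch
        )
        async for _ in response:
            pass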
@@ -459,7 +460,7 @@ async def test_async_chat_azure_stream():
         pytest.fail(f"An exception occurred: {str(e)}")
 
 
-# asyncio.run(test_async_chat_azure_stream())
+asyncio.run(test_async_chat_azure_stream())
 
 
 ## Test Bedrock + sync