fix(azure.py,-openai.py): correctly raise errors if streaming calls fail

This commit is contained in:
Krrish Dholakia 2023-12-27 15:08:37 +05:30
parent 9ba520cc8b
commit c9fdbaf898
6 changed files with 110 additions and 24 deletions

View file

@ -427,14 +427,14 @@ class AzureChatCompletion(BaseLLM):
},
)
response = await azure_client.chat.completions.create(**data)
# return response
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
async def aembedding(
self,

View file

@ -482,8 +482,7 @@ class OpenAIChatCompletion(BaseLLM):
custom_llm_provider="openai",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
return streamwrapper
except (
Exception
) as e: # need to exception handle here. async exceptions don't get caught in sync functions.

View file

@ -198,18 +198,16 @@ async def acompletion(*args, **kwargs):
or custom_llm_provider == "ollama"
or custom_llm_provider == "ollama_chat"
or custom_llm_provider == "vertex_ai"
): # currently implemented aiohttp calls for just azure and openai, soon all.
if kwargs.get("stream", False):
response = completion(*args, **kwargs)
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(
init_response, ModelResponse
): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
else:
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(
init_response, ModelResponse
): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
response = init_response
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)

View file

@ -21,6 +21,13 @@ class MyCustomHandler(CustomLogger):
def log_pre_api_call(self, model, messages, kwargs):
print(f"Pre-API Call")
print(
f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}"
)
self.previous_models += len(
kwargs["litellm_params"]["metadata"]["previous_models"]
) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
print(f"self.previous_models: {self.previous_models}")
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
print(
@ -34,13 +41,6 @@ class MyCustomHandler(CustomLogger):
print(f"On Stream")
def log_success_event(self, kwargs, response_obj, start_time, end_time):
print(
f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}"
)
self.previous_models += len(
kwargs["litellm_params"]["metadata"]["previous_models"]
) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
print(f"self.previous_models: {self.previous_models}")
print(f"On Success")
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -396,3 +396,89 @@ async def test_dynamic_fallbacks_async():
# asyncio.run(test_dynamic_fallbacks_async())
@pytest.mark.asyncio
async def test_async_fallbacks_streaming():
litellm.set_verbose = False
model_list = [
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
},
{ # list of model deployments
"model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
},
{
"model_name": "azure/gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
},
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000,
},
{
"model_name": "gpt-3.5-turbo-16k", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-16k",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000,
},
]
router = Router(
model_list=model_list,
fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
context_window_fallbacks=[
{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]},
{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]},
],
set_verbose=False,
)
customHandler = MyCustomHandler()
litellm.callbacks = [customHandler]
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await router.acompletion(**kwargs, stream=True)
print(f"customHandler.previous_models: {customHandler.previous_models}")
await asyncio.sleep(
0.05
) # allow a delay as success_callbacks are on a separate thread
assert customHandler.previous_models == 1 # 0 retries, 1 fallback
router.reset()
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
finally:
router.reset()

View file

@ -271,7 +271,7 @@ def test_completion_azure_stream():
pytest.fail(f"Error occurred: {e}")
# test_completion_azure_stream()
test_completion_azure_stream()
def test_completion_azure_function_calling_stream():

View file

@ -6739,7 +6739,10 @@ class CustomStreamWrapper:
if str_line.choices[0].finish_reason:
is_finished = True
finish_reason = str_line.choices[0].finish_reason
if str_line.choices[0].logprobs is not None:
if (
"logprobs" in str_line.choices[0]
and str_line.choices[0].logprobs is not None
):
logprobs = str_line.choices[0].logprobs
else:
logprobs = None