diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 123caafd0..b2614faa0 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -427,14 +427,14 @@ class AzureChatCompletion(BaseLLM):
             },
         )
         response = await azure_client.chat.completions.create(**data)
+        # return response
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
             model=model,
             custom_llm_provider="azure",
             logging_obj=logging_obj,
         )
-        async for transformed_chunk in streamwrapper:
-            yield transformed_chunk
+        return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
 
     async def aembedding(
         self,
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 5ef495631..c887eb405 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -482,8 +482,7 @@ class OpenAIChatCompletion(BaseLLM):
                 custom_llm_provider="openai",
                 logging_obj=logging_obj,
             )
-            async for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+            return streamwrapper
         except (
             Exception
         ) as e:  # need to exception handle here. async exceptions don't get caught in sync functions.
diff --git a/litellm/main.py b/litellm/main.py
index 4421f4c0e..2f4894410 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -198,18 +198,16 @@ async def acompletion(*args, **kwargs):
         or custom_llm_provider == "ollama"
         or custom_llm_provider == "ollama_chat"
         or custom_llm_provider == "vertex_ai"
-    ):  # currently implemented aiohttp calls for just azure and openai, soon all.
-        if kwargs.get("stream", False):
-            response = completion(*args, **kwargs)
+    ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if isinstance(init_response, dict) or isinstance(
+            init_response, ModelResponse
+        ):  ## CACHING SCENARIO
+            response = init_response
+        elif asyncio.iscoroutine(init_response):
+            response = await init_response
         else:
-            # Await normally
-            init_response = await loop.run_in_executor(None, func_with_context)
-            if isinstance(init_response, dict) or isinstance(
-                init_response, ModelResponse
-            ):  ## CACHING SCENARIO
-                response = init_response
-            elif asyncio.iscoroutine(init_response):
-                response = await init_response
+            response = init_response
     else:
         # Call the synchronous function using run_in_executor
         response = await loop.run_in_executor(None, func_with_context)
diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py
index bfd3ab3c7..6f97cabec 100644
--- a/litellm/tests/test_router_fallbacks.py
+++ b/litellm/tests/test_router_fallbacks.py
@@ -21,6 +21,13 @@ class MyCustomHandler(CustomLogger):
 
     def log_pre_api_call(self, model, messages, kwargs):
         print(f"Pre-API Call")
+        print(
+            f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}"
+        )
+        self.previous_models += len(
+            kwargs["litellm_params"]["metadata"]["previous_models"]
+        )  # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": }]}
+        print(f"self.previous_models: {self.previous_models}")
 
     def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
         print(
@@ -34,13 +41,6 @@
         print(f"On Stream")
 
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
-        print(
-            f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}"
-        )
-        self.previous_models += len(
kwargs["litellm_params"]["metadata"]["previous_models"] - ) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": }]} - print(f"self.previous_models: {self.previous_models}") print(f"On Success") async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): @@ -396,3 +396,89 @@ async def test_dynamic_fallbacks_async(): # asyncio.run(test_dynamic_fallbacks_async()) + + +@pytest.mark.asyncio +async def test_async_fallbacks_streaming(): + litellm.set_verbose = False + model_list = [ + { # list of model deployments + "model_name": "azure/gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": "bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + { # list of model deployments + "model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + { + "model_name": "azure/gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-functioncalling", + "api_key": "bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 1000000, + "rpm": 9000, + }, + { + "model_name": "gpt-3.5-turbo-16k", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "gpt-3.5-turbo-16k", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + "tpm": 1000000, + "rpm": 9000, + }, + ] + + router = Router( + model_list=model_list, + fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}], + context_window_fallbacks=[ + {"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, + {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}, + ], + set_verbose=False, + ) + customHandler = MyCustomHandler() + litellm.callbacks = [customHandler] + user_message = "Hello, how are you?" 
+ messages = [{"content": user_message, "role": "user"}] + try: + response = await router.acompletion(**kwargs, stream=True) + print(f"customHandler.previous_models: {customHandler.previous_models}") + await asyncio.sleep( + 0.05 + ) # allow a delay as success_callbacks are on a separate thread + assert customHandler.previous_models == 1 # 0 retries, 1 fallback + router.reset() + except litellm.Timeout as e: + pass + except Exception as e: + pytest.fail(f"An exception occurred: {e}") + finally: + router.reset() diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 5d15e6f2c..9a668fdee 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -271,7 +271,7 @@ def test_completion_azure_stream(): pytest.fail(f"Error occurred: {e}") -# test_completion_azure_stream() +test_completion_azure_stream() def test_completion_azure_function_calling_stream(): diff --git a/litellm/utils.py b/litellm/utils.py index e37bf10d4..800748e04 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6739,7 +6739,10 @@ class CustomStreamWrapper: if str_line.choices[0].finish_reason: is_finished = True finish_reason = str_line.choices[0].finish_reason - if str_line.choices[0].logprobs is not None: + if ( + "logprobs" in str_line.choices[0] + and str_line.choices[0].logprobs is not None + ): logprobs = str_line.choices[0].logprobs else: logprobs = None