diff --git a/litellm/litellm_core_utils/asyncify.py b/litellm/litellm_core_utils/asyncify.py
index 5181236e94..8d56a1bbe2 100644
--- a/litellm/litellm_core_utils/asyncify.py
+++ b/litellm/litellm_core_utils/asyncify.py
@@ -1,3 +1,4 @@
+import asyncio
 import functools
 from typing import Awaitable, Callable, Optional
 
@@ -66,3 +67,50 @@ def asyncify(
         )
 
     return wrapper
+
+
+def run_async_function(async_function, *args, **kwargs):
+    """
+    Helper utility to run an async function in a sync context.
+    Handles the case where there is an existing event loop running.
+
+    Args:
+        async_function (Callable): The async function to run
+        *args: Positional arguments to pass to the async function
+        **kwargs: Keyword arguments to pass to the async function
+
+    Returns:
+        The result of the async function execution
+
+    Example:
+        ```python
+        async def my_async_func(x, y):
+            return x + y
+
+        result = run_async_function(my_async_func, 1, 2)
+        ```
+    """
+    from concurrent.futures import ThreadPoolExecutor
+
+    def run_in_new_loop():
+        """Run the coroutine in a new event loop within this thread."""
+        new_loop = asyncio.new_event_loop()
+        try:
+            asyncio.set_event_loop(new_loop)
+            return new_loop.run_until_complete(async_function(*args, **kwargs))
+        finally:
+            new_loop.close()
+            asyncio.set_event_loop(None)
+
+    try:
+        # First, try to get the current event loop
+        _ = asyncio.get_running_loop()
+        # If we're already in an event loop, run in a separate thread
+        # to avoid nested event loop issues
+        with ThreadPoolExecutor(max_workers=1) as executor:
+            future = executor.submit(run_in_new_loop)
+            return future.result()
+
+    except RuntimeError:
+        # No running event loop, we can safely run in this thread
+        return run_in_new_loop()
diff --git a/litellm/litellm_core_utils/fallback_utils.py b/litellm/litellm_core_utils/fallback_utils.py
new file mode 100644
index 0000000000..852165a830
--- /dev/null
+++ b/litellm/litellm_core_utils/fallback_utils.py
@@ -0,0 +1,65 @@
+import uuid
+from copy import deepcopy
+
+import litellm
+from litellm._logging import verbose_logger
+
+from .asyncify import run_async_function
+
+
+async def async_completion_with_fallbacks(**kwargs):
+    """
+    Asynchronously attempts completion with fallback models if the primary model fails.
+
+    Args:
+        **kwargs: Keyword arguments for completion, including:
+            - model (str): Primary model to use
+            - fallbacks (List[Union[str, dict]]): List of fallback models/configs
+            - Other completion parameters
+
+    Returns:
+        ModelResponse: The completion response from the first successful model
+
+    Raises:
+        Exception: If all models fail and no response is generated
+    """
+    # Extract and prepare parameters
+    nested_kwargs = kwargs.pop("kwargs", {})
+    original_model = kwargs["model"]
+    model = original_model
+    fallbacks = [original_model] + nested_kwargs.pop("fallbacks", [])
+    kwargs.pop("acompletion", None)  # Remove to prevent keyword conflicts
+    litellm_call_id = str(uuid.uuid4())
+    base_kwargs = {**kwargs, **nested_kwargs, "litellm_call_id": litellm_call_id}
+    base_kwargs.pop("model", None)  # Remove model as it will be set per fallback
+
+    # Try each fallback model
+    for fallback in fallbacks:
+        try:
+            completion_kwargs = deepcopy(base_kwargs)
+
+            # Handle dictionary fallback configurations
+            if isinstance(fallback, dict):
+                model = fallback.get("model", original_model)
+                completion_kwargs.update(fallback)
+            else:
+                model = fallback
+
+            response = await litellm.acompletion(**completion_kwargs, model=model)
+
+            if response is not None:
+                return response
+
+        except Exception as e:
+            verbose_logger.exception(
+                f"Fallback attempt failed for model {model}: {str(e)}"
+            )
+            continue
+
+    raise Exception(
+        "All fallback attempts failed. Enable verbose logging with `litellm.set_verbose=True` for details."
+    )
+
+
+def completion_with_fallbacks(**kwargs):
+    return run_async_function(async_function=async_completion_with_fallbacks, **kwargs)
diff --git a/litellm/main.py b/litellm/main.py
index f094d74877..b6b35969ba 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -75,9 +75,7 @@ from litellm.utils import (
     CustomStreamWrapper,
     ProviderConfigManager,
     Usage,
-    async_completion_with_fallbacks,
     async_mock_completion_streaming_obj,
-    completion_with_fallbacks,
     convert_to_model_response_object,
     create_pretrained_tokenizer,
     create_tokenizer,
@@ -98,6 +96,10 @@ from litellm.utils import (
 
 from ._logging import verbose_logger
 from .caching.caching import disable_cache, enable_cache, update_cache
+from .litellm_core_utils.fallback_utils import (
+    async_completion_with_fallbacks,
+    completion_with_fallbacks,
+)
 from .litellm_core_utils.prompt_templates.common_utils import get_completion_messages
 from .litellm_core_utils.prompt_templates.factory import (
     custom_prompt,
diff --git a/litellm/router.py b/litellm/router.py
index a15f6a5bcb..cb22ac6d67 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -46,6 +46,7 @@ from litellm import get_secret_str
 from litellm._logging import verbose_router_logger
 from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.asyncify import run_async_function
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
@@ -3264,32 +3265,7 @@ class Router:
 
         Wrapped to reduce code duplication and prevent bugs.
""" - from concurrent.futures import ThreadPoolExecutor - - def run_in_new_loop(): - """Run the coroutine in a new event loop within this thread.""" - new_loop = asyncio.new_event_loop() - try: - asyncio.set_event_loop(new_loop) - return new_loop.run_until_complete( - self.async_function_with_fallbacks(*args, **kwargs) - ) - finally: - new_loop.close() - asyncio.set_event_loop(None) - - try: - # First, try to get the current event loop - _ = asyncio.get_running_loop() - # If we're already in an event loop, run in a separate thread - # to avoid nested event loop issues - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(run_in_new_loop) - return future.result() - - except RuntimeError: - # No running event loop, we can safely run in this thread - return run_in_new_loop() + return run_async_function(self.async_function_with_fallbacks, *args, **kwargs) def _get_fallback_model_group_from_fallbacks( self, diff --git a/litellm/utils.py b/litellm/utils.py index 97dc8537ea..84542789e6 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5243,134 +5243,6 @@ def read_config_args(config_path) -> dict: ########## experimental completion variants ############################ -def completion_with_fallbacks(**kwargs): - nested_kwargs = kwargs.pop("kwargs", {}) - response = None - rate_limited_models = set() - model_expiration_times = {} - start_time = time.time() - original_model = kwargs["model"] - fallbacks = [kwargs["model"]] + nested_kwargs.get("fallbacks", []) - if "fallbacks" in nested_kwargs: - del nested_kwargs["fallbacks"] # remove fallbacks so it's not recursive - litellm_call_id = str(uuid.uuid4()) - - # max time to process a request with fallbacks: default 45s - while response is None and time.time() - start_time < 45: - for model in fallbacks: - # loop thru all models - try: - # check if it's dict or new model string - if isinstance( - model, dict - ): # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}]) - kwargs["api_key"] = model.get("api_key", None) - kwargs["api_base"] = model.get("api_base", None) - model = model.get("model", original_model) - elif ( - model in rate_limited_models - ): # check if model is currently cooling down - if ( - model_expiration_times.get(model) - and time.time() >= model_expiration_times[model] - ): - rate_limited_models.remove( - model - ) # check if it's been 60s of cool down and remove model - else: - continue # skip model - - # delete model from kwargs if it exists - if kwargs.get("model"): - del kwargs["model"] - - print_verbose(f"trying to make completion call with model: {model}") - kwargs["litellm_call_id"] = litellm_call_id - kwargs = { - **kwargs, - **nested_kwargs, - } # combine the openai + litellm params at the same level - response = litellm.completion(**kwargs, model=model) - print_verbose(f"response: {response}") - if response is not None: - return response - - except Exception as e: - print_verbose(e) - rate_limited_models.add(model) - model_expiration_times[model] = ( - time.time() + 60 - ) # cool down this selected model - pass - return response - - -async def async_completion_with_fallbacks(**kwargs): - nested_kwargs = kwargs.pop("kwargs", {}) - response = None - rate_limited_models = set() - model_expiration_times = {} - start_time = time.time() - original_model = kwargs["model"] - fallbacks = [kwargs["model"]] + nested_kwargs.get("fallbacks", []) - if "fallbacks" in nested_kwargs: - del nested_kwargs["fallbacks"] # remove fallbacks so it's not recursive - 
if "acompletion" in kwargs: - del kwargs[ - "acompletion" - ] # remove acompletion so it doesn't lead to keyword errors - litellm_call_id = str(uuid.uuid4()) - - # max time to process a request with fallbacks: default 45s - while response is None and time.time() - start_time < 45: - for model in fallbacks: - # loop thru all models - try: - # check if it's dict or new model string - if isinstance( - model, dict - ): # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}]) - kwargs["api_key"] = model.get("api_key", None) - kwargs["api_base"] = model.get("api_base", None) - model = model.get("model", original_model) - elif ( - model in rate_limited_models - ): # check if model is currently cooling down - if ( - model_expiration_times.get(model) - and time.time() >= model_expiration_times[model] - ): - rate_limited_models.remove( - model - ) # check if it's been 60s of cool down and remove model - else: - continue # skip model - - # delete model from kwargs if it exists - if kwargs.get("model"): - del kwargs["model"] - - print_verbose(f"trying to make completion call with model: {model}") - kwargs["litellm_call_id"] = litellm_call_id - kwargs = { - **kwargs, - **nested_kwargs, - } # combine the openai + litellm params at the same level - response = await litellm.acompletion(**kwargs, model=model) - print_verbose(f"response: {response}") - if response is not None: - return response - - except Exception as e: - print_verbose(f"error: {e}") - rate_limited_models.add(model) - model_expiration_times[model] = ( - time.time() + 60 - ) # cool down this selected model - pass - return response - - def process_system_message(system_message, max_tokens, model): system_message_event = {"role": "system", "content": system_message} system_message_tokens = get_token_count([system_message_event], model) diff --git a/tests/local_testing/test_acompletion_fallbacks.py b/tests/local_testing/test_acompletion_fallbacks.py new file mode 100644 index 0000000000..00c2139f27 --- /dev/null +++ b/tests/local_testing/test_acompletion_fallbacks.py @@ -0,0 +1,103 @@ +import asyncio +import os +import sys +import time +import traceback + +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import concurrent + +from dotenv import load_dotenv +import asyncio +import litellm + + +@pytest.mark.asyncio +async def test_acompletion_fallbacks_basic(): + response = await litellm.acompletion( + model="openai/unknown-model", + messages=[{"role": "user", "content": "Hello, world!"}], + fallbacks=["openai/gpt-4o-mini"], + ) + print(response) + assert response is not None + + +@pytest.mark.asyncio +async def test_acompletion_fallbacks_bad_models(): + """ + Test that the acompletion call times out after 10 seconds - if no fallbacks work + """ + try: + # Wrap the acompletion call with asyncio.wait_for to enforce a timeout + response = await asyncio.wait_for( + litellm.acompletion( + model="openai/unknown-model", + messages=[{"role": "user", "content": "Hello, world!"}], + fallbacks=["openai/bad-model", "openai/unknown-model"], + ), + timeout=5.0, # Timeout after 5 seconds + ) + assert response is not None + except asyncio.TimeoutError: + pytest.fail("Test timed out - possible infinite loop in fallbacks") + except Exception as e: + print(e) + pass + + +@pytest.mark.asyncio +async def test_acompletion_fallbacks_with_dict_config(): + """ + Test fallbacks with dictionary configuration that includes model-specific settings + """ + 
+    response = await litellm.acompletion(
+        model="openai/gpt-4o-mini",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        api_key="very-bad-api-key",
+        fallbacks=[{"api_key": os.getenv("OPENAI_API_KEY")}],
+    )
+    assert response is not None
+
+
+@pytest.mark.asyncio
+async def test_acompletion_fallbacks_empty_list():
+    """
+    Test behavior when fallbacks list is empty
+    """
+    try:
+        response = await litellm.acompletion(
+            model="openai/unknown-model",
+            messages=[{"role": "user", "content": "Hello, world!"}],
+            fallbacks=[],
+        )
+    except Exception as e:
+        assert isinstance(e, litellm.NotFoundError)
+
+
+@pytest.mark.asyncio
+async def test_acompletion_fallbacks_none_response():
+    """
+    Test handling when a fallback model returns None
+    Should continue to next fallback rather than returning None
+    """
+    response = await litellm.acompletion(
+        model="openai/unknown-model",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        fallbacks=["gpt-3.5-turbo"],  # replace with a model you know works
+    )
+    assert response is not None
+
+
+def test_completion_fallbacks_sync():
+    response = litellm.completion(
+        model="openai/unknown-model",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        fallbacks=["openai/gpt-4o-mini"],
+    )
+    print(response)
+    assert response is not None
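
Below is a minimal usage sketch of the relocated `run_async_function` helper, covering both branches it implements (no running event loop vs. a loop that is already running). The `add` and `caller` names are illustrative only; the import path follows the module added in this patch.

```python
import asyncio

from litellm.litellm_core_utils.asyncify import run_async_function


async def add(x, y):
    # Illustrative coroutine; stands in for any async LiteLLM call.
    return x + y


# No event loop is running here, so the helper creates a fresh loop in the
# current thread and runs the coroutine to completion.
print(run_async_function(add, 1, 2))  # 3


async def caller():
    # A loop is already running, so the helper hands the work to a
    # single-worker ThreadPoolExecutor with its own loop instead of
    # nesting run_until_complete() inside the active loop.
    return run_async_function(add, 3, 4)


print(asyncio.run(caller()))  # 7
```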
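
And a rough sketch of the fallback flow the new tests exercise through `litellm.acompletion(..., fallbacks=[...])`, which now routes into `fallback_utils.async_completion_with_fallbacks`: the primary `model` is attempted first, string fallbacks only swap the model, dict fallbacks are merged into the call kwargs (so they can override `model`, `api_key`, `api_base`, etc.), and an exception is raised if every attempt fails. The Azure deployment name and environment variables below are placeholders, not part of the patch.

```python
import asyncio
import os

import litellm


async def main():
    # Tries the primary model first, then each fallback in order,
    # returning the first successful response.
    return await litellm.acompletion(
        model="openai/gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello, world!"}],
        fallbacks=[
            "openai/gpt-4o",  # plain string: only the model is swapped
            {
                # dict: merged into this attempt's kwargs (placeholder values)
                "model": "azure/my-gpt-4o-mini-deployment",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
        ],
    )


print(asyncio.run(main()))
```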