mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
(core sdk fix) - fix fallbacks stuck in infinite loop (#7751)
* test_acompletion_fallbacks_basic * use common run_async_function * fix completion_with_fallbacks * fix completion with fallbacks * fix fallback utils * test_acompletion_fallbacks_basic * test_completion_fallbacks_sync * huggingface/mistralai/Mistral-7B-Instruct-v0.3
This commit is contained in:
parent
a66fd515bb
commit
392eb265f9
6 changed files with 222 additions and 156 deletions
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import functools
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
|
@ -66,3 +67,50 @@ def asyncify(
|
|||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def run_async_function(async_function, *args, **kwargs):
|
||||
"""
|
||||
Helper utility to run an async function in a sync context.
|
||||
Handles the case where there is an existing event loop running.
|
||||
|
||||
Args:
|
||||
async_function (Callable): The async function to run
|
||||
*args: Positional arguments to pass to the async function
|
||||
**kwargs: Keyword arguments to pass to the async function
|
||||
|
||||
Returns:
|
||||
The result of the async function execution
|
||||
|
||||
Example:
|
||||
```python
|
||||
async def my_async_func(x, y):
|
||||
return x + y
|
||||
|
||||
result = run_async_function(my_async_func, 1, 2)
|
||||
```
|
||||
"""
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
def run_in_new_loop():
|
||||
"""Run the coroutine in a new event loop within this thread."""
|
||||
new_loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(new_loop)
|
||||
return new_loop.run_until_complete(async_function(*args, **kwargs))
|
||||
finally:
|
||||
new_loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
try:
|
||||
# First, try to get the current event loop
|
||||
_ = asyncio.get_running_loop()
|
||||
# If we're already in an event loop, run in a separate thread
|
||||
# to avoid nested event loop issues
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
return future.result()
|
||||
|
||||
except RuntimeError:
|
||||
# No running event loop, we can safely run in this thread
|
||||
return run_in_new_loop()
|
||||
|
|
65
litellm/litellm_core_utils/fallback_utils.py
Normal file
65
litellm/litellm_core_utils/fallback_utils.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
import uuid
|
||||
from copy import deepcopy
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
from .asyncify import run_async_function
|
||||
|
||||
|
||||
async def async_completion_with_fallbacks(**kwargs):
|
||||
"""
|
||||
Asynchronously attempts completion with fallback models if the primary model fails.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments for completion, including:
|
||||
- model (str): Primary model to use
|
||||
- fallbacks (List[Union[str, dict]]): List of fallback models/configs
|
||||
- Other completion parameters
|
||||
|
||||
Returns:
|
||||
ModelResponse: The completion response from the first successful model
|
||||
|
||||
Raises:
|
||||
Exception: If all models fail and no response is generated
|
||||
"""
|
||||
# Extract and prepare parameters
|
||||
nested_kwargs = kwargs.pop("kwargs", {})
|
||||
original_model = kwargs["model"]
|
||||
model = original_model
|
||||
fallbacks = [original_model] + nested_kwargs.pop("fallbacks", [])
|
||||
kwargs.pop("acompletion", None) # Remove to prevent keyword conflicts
|
||||
litellm_call_id = str(uuid.uuid4())
|
||||
base_kwargs = {**kwargs, **nested_kwargs, "litellm_call_id": litellm_call_id}
|
||||
base_kwargs.pop("model", None) # Remove model as it will be set per fallback
|
||||
|
||||
# Try each fallback model
|
||||
for fallback in fallbacks:
|
||||
try:
|
||||
completion_kwargs = deepcopy(base_kwargs)
|
||||
|
||||
# Handle dictionary fallback configurations
|
||||
if isinstance(fallback, dict):
|
||||
model = fallback.get("model", original_model)
|
||||
completion_kwargs.update(fallback)
|
||||
else:
|
||||
model = fallback
|
||||
|
||||
response = await litellm.acompletion(**completion_kwargs, model=model)
|
||||
|
||||
if response is not None:
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"Fallback attempt failed for model {model}: {str(e)}"
|
||||
)
|
||||
continue
|
||||
|
||||
raise Exception(
|
||||
"All fallback attempts failed. Enable verbose logging with `litellm.set_verbose=True` for details."
|
||||
)
|
||||
|
||||
|
||||
def completion_with_fallbacks(**kwargs):
|
||||
return run_async_function(async_function=async_completion_with_fallbacks, **kwargs)
|
Loading…
Add table
Add a link
Reference in a new issue