(core sdk fix) - fix fallbacks stuck in infinite loop (#7751)

* test_acompletion_fallbacks_basic

* use common run_async_function

* fix completion_with_fallbacks

* fix completion with fallbacks

* fix fallback utils

* test_acompletion_fallbacks_basic

* test_completion_fallbacks_sync

* huggingface/mistralai/Mistral-7B-Instruct-v0.3
Ishaan Jaff 2025-01-13 19:34:34 -08:00 committed by GitHub
parent 970e9c7507
commit f1335362cf
6 changed files with 222 additions and 156 deletions


@@ -1,3 +1,4 @@
import asyncio
import functools
from typing import Awaitable, Callable, Optional
@@ -66,3 +67,50 @@ def asyncify(
        )

    return wrapper


def run_async_function(async_function, *args, **kwargs):
    """
    Helper utility to run an async function in a sync context.

    Handles the case where there is an existing event loop running.

    Args:
        async_function (Callable): The async function to run
        *args: Positional arguments to pass to the async function
        **kwargs: Keyword arguments to pass to the async function

    Returns:
        The result of the async function execution

    Example:
        ```python
        async def my_async_func(x, y):
            return x + y

        result = run_async_function(my_async_func, 1, 2)
        ```
    """
    from concurrent.futures import ThreadPoolExecutor

    def run_in_new_loop():
        """Run the coroutine in a new event loop within this thread."""
        new_loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(new_loop)
            return new_loop.run_until_complete(async_function(*args, **kwargs))
        finally:
            new_loop.close()
            asyncio.set_event_loop(None)

    try:
        # First, try to get the current event loop
        _ = asyncio.get_running_loop()

        # If we're already in an event loop, run in a separate thread
        # to avoid nested event loop issues
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(run_in_new_loop)
            return future.result()
    except RuntimeError:
        # No running event loop, we can safely run in this thread
        return run_in_new_loop()
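For reference, a minimal sketch (not part of the diff) of the two paths `run_async_function` handles; the `add` coroutine is purely illustrative:

```python
import asyncio

from litellm.litellm_core_utils.asyncify import run_async_function


async def add(x, y):
    # Illustrative coroutine, not part of the library.
    await asyncio.sleep(0)
    return x + y


# No event loop running here: the coroutine is driven directly on a fresh loop.
assert run_async_function(add, 1, 2) == 3


async def caller():
    # A loop is already running, so the helper hops to a single-worker
    # thread and runs the coroutine on a new loop there.
    return run_async_function(add, 3, 4)


assert asyncio.run(caller()) == 7
```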


@@ -0,0 +1,65 @@
import uuid
from copy import deepcopy

import litellm
from litellm._logging import verbose_logger

from .asyncify import run_async_function


async def async_completion_with_fallbacks(**kwargs):
    """
    Asynchronously attempts completion with fallback models if the primary model fails.

    Args:
        **kwargs: Keyword arguments for completion, including:
            - model (str): Primary model to use
            - fallbacks (List[Union[str, dict]]): List of fallback models/configs
            - Other completion parameters

    Returns:
        ModelResponse: The completion response from the first successful model

    Raises:
        Exception: If all models fail and no response is generated
    """
    # Extract and prepare parameters
    nested_kwargs = kwargs.pop("kwargs", {})
    original_model = kwargs["model"]
    model = original_model
    fallbacks = [original_model] + nested_kwargs.pop("fallbacks", [])
    kwargs.pop("acompletion", None)  # Remove to prevent keyword conflicts
    litellm_call_id = str(uuid.uuid4())
    base_kwargs = {**kwargs, **nested_kwargs, "litellm_call_id": litellm_call_id}
    base_kwargs.pop("model", None)  # Remove model as it will be set per fallback

    # Try each fallback model
    for fallback in fallbacks:
        try:
            completion_kwargs = deepcopy(base_kwargs)

            # Handle dictionary fallback configurations
            if isinstance(fallback, dict):
                model = fallback.get("model", original_model)
                completion_kwargs.update(fallback)
            else:
                model = fallback

            response = await litellm.acompletion(**completion_kwargs, model=model)
            if response is not None:
                return response
        except Exception as e:
            verbose_logger.exception(
                f"Fallback attempt failed for model {model}: {str(e)}"
            )
            continue

    raise Exception(
        "All fallback attempts failed. Enable verbose logging with `litellm.set_verbose=True` for details."
    )


def completion_with_fallbacks(**kwargs):
    return run_async_function(async_function=async_completion_with_fallbacks, **kwargs)
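A minimal usage sketch (not part of the diff), mirroring the tests added below; it assumes a valid `OPENAI_API_KEY` is configured so the fallback model can succeed:

```python
import litellm

# Plain synchronous context: passing `fallbacks` to litellm.completion routes
# the call through completion_with_fallbacks, which now drives the async
# fallback loop via run_async_function instead of looping indefinitely.
response = litellm.completion(
    model="openai/unknown-model",
    messages=[{"role": "user", "content": "Hello, world!"}],
    fallbacks=["openai/gpt-4o-mini"],
)
print(response)
```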


@@ -75,9 +75,7 @@ from litellm.utils import (
    CustomStreamWrapper,
    ProviderConfigManager,
    Usage,
    async_completion_with_fallbacks,
    async_mock_completion_streaming_obj,
    completion_with_fallbacks,
    convert_to_model_response_object,
    create_pretrained_tokenizer,
    create_tokenizer,
@@ -98,6 +96,10 @@ from litellm.utils import (
from ._logging import verbose_logger
from .caching.caching import disable_cache, enable_cache, update_cache
from .litellm_core_utils.fallback_utils import (
    async_completion_with_fallbacks,
    completion_with_fallbacks,
)
from .litellm_core_utils.prompt_templates.common_utils import get_completion_messages
from .litellm_core_utils.prompt_templates.factory import (
    custom_prompt,


@@ -46,6 +46,7 @@ from litellm import get_secret_str
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.asyncify import run_async_function
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
@@ -3264,32 +3265,7 @@ class Router:
        Wrapped to reduce code duplication and prevent bugs.
        """
        from concurrent.futures import ThreadPoolExecutor

        def run_in_new_loop():
            """Run the coroutine in a new event loop within this thread."""
            new_loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(new_loop)
                return new_loop.run_until_complete(
                    self.async_function_with_fallbacks(*args, **kwargs)
                )
            finally:
                new_loop.close()
                asyncio.set_event_loop(None)

        try:
            # First, try to get the current event loop
            _ = asyncio.get_running_loop()

            # If we're already in an event loop, run in a separate thread
            # to avoid nested event loop issues
            with ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(run_in_new_loop)
                return future.result()
        except RuntimeError:
            # No running event loop, we can safely run in this thread
            return run_in_new_loop()
        return run_async_function(self.async_function_with_fallbacks, *args, **kwargs)

    def _get_fallback_model_group_from_fallbacks(
        self,


@@ -5243,134 +5243,6 @@ def read_config_args(config_path) -> dict:
########## experimental completion variants ############################


def completion_with_fallbacks(**kwargs):
    nested_kwargs = kwargs.pop("kwargs", {})
    response = None
    rate_limited_models = set()
    model_expiration_times = {}
    start_time = time.time()
    original_model = kwargs["model"]
    fallbacks = [kwargs["model"]] + nested_kwargs.get("fallbacks", [])
    if "fallbacks" in nested_kwargs:
        del nested_kwargs["fallbacks"]  # remove fallbacks so it's not recursive
    litellm_call_id = str(uuid.uuid4())

    # max time to process a request with fallbacks: default 45s
    while response is None and time.time() - start_time < 45:
        for model in fallbacks:
            # loop thru all models
            try:
                # check if it's dict or new model string
                if isinstance(
                    model, dict
                ):  # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}])
                    kwargs["api_key"] = model.get("api_key", None)
                    kwargs["api_base"] = model.get("api_base", None)
                    model = model.get("model", original_model)
                elif (
                    model in rate_limited_models
                ):  # check if model is currently cooling down
                    if (
                        model_expiration_times.get(model)
                        and time.time() >= model_expiration_times[model]
                    ):
                        rate_limited_models.remove(
                            model
                        )  # check if it's been 60s of cool down and remove model
                    else:
                        continue  # skip model

                # delete model from kwargs if it exists
                if kwargs.get("model"):
                    del kwargs["model"]

                print_verbose(f"trying to make completion call with model: {model}")
                kwargs["litellm_call_id"] = litellm_call_id
                kwargs = {
                    **kwargs,
                    **nested_kwargs,
                }  # combine the openai + litellm params at the same level
                response = litellm.completion(**kwargs, model=model)
                print_verbose(f"response: {response}")
                if response is not None:
                    return response
            except Exception as e:
                print_verbose(e)
                rate_limited_models.add(model)
                model_expiration_times[model] = (
                    time.time() + 60
                )  # cool down this selected model
                pass
    return response


async def async_completion_with_fallbacks(**kwargs):
    nested_kwargs = kwargs.pop("kwargs", {})
    response = None
    rate_limited_models = set()
    model_expiration_times = {}
    start_time = time.time()
    original_model = kwargs["model"]
    fallbacks = [kwargs["model"]] + nested_kwargs.get("fallbacks", [])
    if "fallbacks" in nested_kwargs:
        del nested_kwargs["fallbacks"]  # remove fallbacks so it's not recursive
    if "acompletion" in kwargs:
        del kwargs[
            "acompletion"
        ]  # remove acompletion so it doesn't lead to keyword errors
    litellm_call_id = str(uuid.uuid4())

    # max time to process a request with fallbacks: default 45s
    while response is None and time.time() - start_time < 45:
        for model in fallbacks:
            # loop thru all models
            try:
                # check if it's dict or new model string
                if isinstance(
                    model, dict
                ):  # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}])
                    kwargs["api_key"] = model.get("api_key", None)
                    kwargs["api_base"] = model.get("api_base", None)
                    model = model.get("model", original_model)
                elif (
                    model in rate_limited_models
                ):  # check if model is currently cooling down
                    if (
                        model_expiration_times.get(model)
                        and time.time() >= model_expiration_times[model]
                    ):
                        rate_limited_models.remove(
                            model
                        )  # check if it's been 60s of cool down and remove model
                    else:
                        continue  # skip model

                # delete model from kwargs if it exists
                if kwargs.get("model"):
                    del kwargs["model"]

                print_verbose(f"trying to make completion call with model: {model}")
                kwargs["litellm_call_id"] = litellm_call_id
                kwargs = {
                    **kwargs,
                    **nested_kwargs,
                }  # combine the openai + litellm params at the same level
                response = await litellm.acompletion(**kwargs, model=model)
                print_verbose(f"response: {response}")
                if response is not None:
                    return response
            except Exception as e:
                print_verbose(f"error: {e}")
                rate_limited_models.add(model)
                model_expiration_times[model] = (
                    time.time() + 60
                )  # cool down this selected model
                pass
    return response


def process_system_message(system_message, max_tokens, model):
    system_message_event = {"role": "system", "content": system_message}
    system_message_tokens = get_token_count([system_message_event], model)


@@ -0,0 +1,103 @@
import asyncio
import os
import sys
import time
import traceback

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import concurrent

from dotenv import load_dotenv

import litellm


@pytest.mark.asyncio
async def test_acompletion_fallbacks_basic():
    response = await litellm.acompletion(
        model="openai/unknown-model",
        messages=[{"role": "user", "content": "Hello, world!"}],
        fallbacks=["openai/gpt-4o-mini"],
    )
    print(response)
    assert response is not None


@pytest.mark.asyncio
async def test_acompletion_fallbacks_bad_models():
    """
    Test that the acompletion call does not hang when no fallbacks work - a 5 second timeout is enforced
    """
    try:
        # Wrap the acompletion call with asyncio.wait_for to enforce a timeout
        response = await asyncio.wait_for(
            litellm.acompletion(
                model="openai/unknown-model",
                messages=[{"role": "user", "content": "Hello, world!"}],
                fallbacks=["openai/bad-model", "openai/unknown-model"],
            ),
            timeout=5.0,  # Timeout after 5 seconds
        )
        assert response is not None
    except asyncio.TimeoutError:
        pytest.fail("Test timed out - possible infinite loop in fallbacks")
    except Exception as e:
        print(e)
        pass


@pytest.mark.asyncio
async def test_acompletion_fallbacks_with_dict_config():
    """
    Test fallbacks with dictionary configuration that includes model-specific settings
    """
    response = await litellm.acompletion(
        model="openai/gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello, world!"}],
        api_key="very-bad-api-key",
        fallbacks=[{"api_key": os.getenv("OPENAI_API_KEY")}],
    )
    assert response is not None


@pytest.mark.asyncio
async def test_acompletion_fallbacks_empty_list():
    """
    Test behavior when fallbacks list is empty
    """
    try:
        response = await litellm.acompletion(
            model="openai/unknown-model",
            messages=[{"role": "user", "content": "Hello, world!"}],
            fallbacks=[],
        )
    except Exception as e:
        assert isinstance(e, litellm.NotFoundError)


@pytest.mark.asyncio
async def test_acompletion_fallbacks_none_response():
    """
    Test handling when a fallback model returns None
    Should continue to next fallback rather than returning None
    """
    response = await litellm.acompletion(
        model="openai/unknown-model",
        messages=[{"role": "user", "content": "Hello, world!"}],
        fallbacks=["gpt-3.5-turbo"],  # replace with a model you know works
    )
    assert response is not None
@pytest.mark.asyncio
async def test_completion_fallbacks_sync():
    response = litellm.completion(
        model="openai/unknown-model",
        messages=[{"role": "user", "content": "Hello, world!"}],
        fallbacks=["openai/gpt-4o-mini"],
    )
    print(response)
    assert response is not None