mirror of https://github.com/BerriAI/litellm.git
(core sdk fix) - fix fallbacks stuck in infinite loop (#7751)
* test_acompletion_fallbacks_basic
* use common run_async_function
* fix completion_with_fallbacks
* fix completion with fallbacks
* fix fallback utils
* test_acompletion_fallbacks_basic
* test_completion_fallbacks_sync
* huggingface/mistralai/Mistral-7B-Instruct-v0.3
parent a66fd515bb
commit 392eb265f9

6 changed files with 222 additions and 156 deletions
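For context, the bug lives in the SDK-level `fallbacks` parameter. A minimal sketch of the affected call, mirroring the new sync test added below (model names are placeholders):

```python
import litellm

# Before this fix, this sync call could get stuck when every model failed;
# after it, each fallback is tried once and an exception is raised on total failure.
response = litellm.completion(
    model="openai/unknown-model",  # placeholder primary model that fails
    messages=[{"role": "user", "content": "Hello, world!"}],
    fallbacks=["openai/gpt-4o-mini"],  # placeholder fallback that succeeds
)
print(response)
```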
litellm/litellm_core_utils/asyncify.py

@@ -1,3 +1,4 @@
+import asyncio
 import functools
 from typing import Awaitable, Callable, Optional

@@ -66,3 +67,50 @@ def asyncify(
     )

     return wrapper
+
+
+def run_async_function(async_function, *args, **kwargs):
+    """
+    Helper utility to run an async function in a sync context.
+    Handles the case where there is an existing event loop running.
+
+    Args:
+        async_function (Callable): The async function to run
+        *args: Positional arguments to pass to the async function
+        **kwargs: Keyword arguments to pass to the async function
+
+    Returns:
+        The result of the async function execution
+
+    Example:
+        ```python
+        async def my_async_func(x, y):
+            return x + y
+
+        result = run_async_function(my_async_func, 1, 2)
+        ```
+    """
+    from concurrent.futures import ThreadPoolExecutor
+
+    def run_in_new_loop():
+        """Run the coroutine in a new event loop within this thread."""
+        new_loop = asyncio.new_event_loop()
+        try:
+            asyncio.set_event_loop(new_loop)
+            return new_loop.run_until_complete(async_function(*args, **kwargs))
+        finally:
+            new_loop.close()
+            asyncio.set_event_loop(None)
+
+    try:
+        # First, try to get the current event loop
+        _ = asyncio.get_running_loop()
+        # If we're already in an event loop, run in a separate thread
+        # to avoid nested event loop issues
+        with ThreadPoolExecutor(max_workers=1) as executor:
+            future = executor.submit(run_in_new_loop)
+            return future.result()
+
+    except RuntimeError:
+        # No running event loop, we can safely run in this thread
+        return run_in_new_loop()
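As a usage note (not part of the diff): the helper works whether or not a loop is already running. A minimal sketch, assuming only the code shown above:

```python
import asyncio

from litellm.litellm_core_utils.asyncify import run_async_function


async def add(x, y):
    return x + y


# No event loop running: the coroutine executes in a fresh loop in this thread.
print(run_async_function(add, 1, 2))  # 3


async def caller():
    # A loop is already running here, so the helper hands the coroutine to a
    # fresh loop on a worker thread instead of nesting event loops.
    return run_async_function(add, 3, 4)


print(asyncio.run(caller()))  # 7
```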
litellm/litellm_core_utils/fallback_utils.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+import uuid
+from copy import deepcopy
+
+import litellm
+from litellm._logging import verbose_logger
+
+from .asyncify import run_async_function
+
+
+async def async_completion_with_fallbacks(**kwargs):
+    """
+    Asynchronously attempts completion with fallback models if the primary model fails.
+
+    Args:
+        **kwargs: Keyword arguments for completion, including:
+            - model (str): Primary model to use
+            - fallbacks (List[Union[str, dict]]): List of fallback models/configs
+            - Other completion parameters
+
+    Returns:
+        ModelResponse: The completion response from the first successful model
+
+    Raises:
+        Exception: If all models fail and no response is generated
+    """
+    # Extract and prepare parameters
+    nested_kwargs = kwargs.pop("kwargs", {})
+    original_model = kwargs["model"]
+    model = original_model
+    fallbacks = [original_model] + nested_kwargs.pop("fallbacks", [])
+    kwargs.pop("acompletion", None)  # Remove to prevent keyword conflicts
+    litellm_call_id = str(uuid.uuid4())
+    base_kwargs = {**kwargs, **nested_kwargs, "litellm_call_id": litellm_call_id}
+    base_kwargs.pop("model", None)  # Remove model as it will be set per fallback
+
+    # Try each fallback model
+    for fallback in fallbacks:
+        try:
+            completion_kwargs = deepcopy(base_kwargs)
+
+            # Handle dictionary fallback configurations
+            if isinstance(fallback, dict):
+                model = fallback.get("model", original_model)
+                completion_kwargs.update(fallback)
+            else:
+                model = fallback
+
+            response = await litellm.acompletion(**completion_kwargs, model=model)
+
+            if response is not None:
+                return response
+
+        except Exception as e:
+            verbose_logger.exception(
+                f"Fallback attempt failed for model {model}: {str(e)}"
+            )
+            continue
+
+    raise Exception(
+        "All fallback attempts failed. Enable verbose logging with `litellm.set_verbose=True` for details."
+    )
+
+
+def completion_with_fallbacks(**kwargs):
+    return run_async_function(async_function=async_completion_with_fallbacks, **kwargs)
litellm/main.py

@@ -75,9 +75,7 @@ from litellm.utils import (
     CustomStreamWrapper,
     ProviderConfigManager,
     Usage,
-    async_completion_with_fallbacks,
     async_mock_completion_streaming_obj,
-    completion_with_fallbacks,
     convert_to_model_response_object,
     create_pretrained_tokenizer,
     create_tokenizer,
@@ -98,6 +96,10 @@ from litellm.utils import (

 from ._logging import verbose_logger
 from .caching.caching import disable_cache, enable_cache, update_cache
+from .litellm_core_utils.fallback_utils import (
+    async_completion_with_fallbacks,
+    completion_with_fallbacks,
+)
 from .litellm_core_utils.prompt_templates.common_utils import get_completion_messages
 from .litellm_core_utils.prompt_templates.factory import (
     custom_prompt,
litellm/router.py

@@ -46,6 +46,7 @@ from litellm import get_secret_str
 from litellm._logging import verbose_router_logger
 from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.asyncify import run_async_function
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
@@ -3264,32 +3265,7 @@ class Router:

         Wrapped to reduce code duplication and prevent bugs.
         """
-        from concurrent.futures import ThreadPoolExecutor
-
-        def run_in_new_loop():
-            """Run the coroutine in a new event loop within this thread."""
-            new_loop = asyncio.new_event_loop()
-            try:
-                asyncio.set_event_loop(new_loop)
-                return new_loop.run_until_complete(
-                    self.async_function_with_fallbacks(*args, **kwargs)
-                )
-            finally:
-                new_loop.close()
-                asyncio.set_event_loop(None)
-
-        try:
-            # First, try to get the current event loop
-            _ = asyncio.get_running_loop()
-            # If we're already in an event loop, run in a separate thread
-            # to avoid nested event loop issues
-            with ThreadPoolExecutor(max_workers=1) as executor:
-                future = executor.submit(run_in_new_loop)
-                return future.result()
-
-        except RuntimeError:
-            # No running event loop, we can safely run in this thread
-            return run_in_new_loop()
+        return run_async_function(self.async_function_with_fallbacks, *args, **kwargs)

     def _get_fallback_model_group_from_fallbacks(
         self,
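A sketch of how this is exercised from the Router's sync path. The deployment config here is a placeholder, and the claim that `router.completion` reaches `function_with_fallbacks` is an assumption based on this diff's context, not something the diff itself shows:

```python
from litellm import Router

# Hypothetical single-deployment setup; model name and params are placeholders.
router = Router(
    model_list=[
        {
            "model_name": "my-model",
            "litellm_params": {"model": "openai/gpt-4o-mini"},
        }
    ]
)

# The sync fallback path now delegates to run_async_function instead of
# managing event loops inline, so it behaves the same whether or not an
# event loop is already running in the caller.
response = router.completion(
    model="my-model",
    messages=[{"role": "user", "content": "Hello, world!"}],
)
```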
litellm/utils.py (net -128 lines)
@@ -5243,134 +5243,6 @@ def read_config_args(config_path) -> dict:
 ########## experimental completion variants ############################


-def completion_with_fallbacks(**kwargs):
-    nested_kwargs = kwargs.pop("kwargs", {})
-    response = None
-    rate_limited_models = set()
-    model_expiration_times = {}
-    start_time = time.time()
-    original_model = kwargs["model"]
-    fallbacks = [kwargs["model"]] + nested_kwargs.get("fallbacks", [])
-    if "fallbacks" in nested_kwargs:
-        del nested_kwargs["fallbacks"]  # remove fallbacks so it's not recursive
-    litellm_call_id = str(uuid.uuid4())
-
-    # max time to process a request with fallbacks: default 45s
-    while response is None and time.time() - start_time < 45:
-        for model in fallbacks:
-            # loop thru all models
-            try:
-                # check if it's dict or new model string
-                if isinstance(
-                    model, dict
-                ):  # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}])
-                    kwargs["api_key"] = model.get("api_key", None)
-                    kwargs["api_base"] = model.get("api_base", None)
-                    model = model.get("model", original_model)
-                elif (
-                    model in rate_limited_models
-                ):  # check if model is currently cooling down
-                    if (
-                        model_expiration_times.get(model)
-                        and time.time() >= model_expiration_times[model]
-                    ):
-                        rate_limited_models.remove(
-                            model
-                        )  # check if it's been 60s of cool down and remove model
-                    else:
-                        continue  # skip model
-
-                # delete model from kwargs if it exists
-                if kwargs.get("model"):
-                    del kwargs["model"]
-
-                print_verbose(f"trying to make completion call with model: {model}")
-                kwargs["litellm_call_id"] = litellm_call_id
-                kwargs = {
-                    **kwargs,
-                    **nested_kwargs,
-                }  # combine the openai + litellm params at the same level
-                response = litellm.completion(**kwargs, model=model)
-                print_verbose(f"response: {response}")
-                if response is not None:
-                    return response
-
-            except Exception as e:
-                print_verbose(e)
-                rate_limited_models.add(model)
-                model_expiration_times[model] = (
-                    time.time() + 60
-                )  # cool down this selected model
-                pass
-    return response
-
-
-async def async_completion_with_fallbacks(**kwargs):
-    nested_kwargs = kwargs.pop("kwargs", {})
-    response = None
-    rate_limited_models = set()
-    model_expiration_times = {}
-    start_time = time.time()
-    original_model = kwargs["model"]
-    fallbacks = [kwargs["model"]] + nested_kwargs.get("fallbacks", [])
-    if "fallbacks" in nested_kwargs:
-        del nested_kwargs["fallbacks"]  # remove fallbacks so it's not recursive
-    if "acompletion" in kwargs:
-        del kwargs[
-            "acompletion"
-        ]  # remove acompletion so it doesn't lead to keyword errors
-    litellm_call_id = str(uuid.uuid4())
-
-    # max time to process a request with fallbacks: default 45s
-    while response is None and time.time() - start_time < 45:
-        for model in fallbacks:
-            # loop thru all models
-            try:
-                # check if it's dict or new model string
-                if isinstance(
-                    model, dict
-                ):  # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}])
-                    kwargs["api_key"] = model.get("api_key", None)
-                    kwargs["api_base"] = model.get("api_base", None)
-                    model = model.get("model", original_model)
-                elif (
-                    model in rate_limited_models
-                ):  # check if model is currently cooling down
-                    if (
-                        model_expiration_times.get(model)
-                        and time.time() >= model_expiration_times[model]
-                    ):
-                        rate_limited_models.remove(
-                            model
-                        )  # check if it's been 60s of cool down and remove model
-                    else:
-                        continue  # skip model
-
-                # delete model from kwargs if it exists
-                if kwargs.get("model"):
-                    del kwargs["model"]
-
-                print_verbose(f"trying to make completion call with model: {model}")
-                kwargs["litellm_call_id"] = litellm_call_id
-                kwargs = {
-                    **kwargs,
-                    **nested_kwargs,
-                }  # combine the openai + litellm params at the same level
-                response = await litellm.acompletion(**kwargs, model=model)
-                print_verbose(f"response: {response}")
-                if response is not None:
-                    return response
-
-            except Exception as e:
-                print_verbose(f"error: {e}")
-                rate_limited_models.add(model)
-                model_expiration_times[model] = (
-                    time.time() + 60
-                )  # cool down this selected model
-                pass
-    return response
-
-
 def process_system_message(system_message, max_tokens, model):
     system_message_event = {"role": "system", "content": system_message}
     system_message_tokens = get_token_count([system_message_event], model)
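For contrast, a hypothetical sketch (not code from the commit) of the control-flow change: the removed helpers re-cycled the fallback list inside a 45-second while-loop and could fall through with `None`, while the replacement in `fallback_utils.py` makes a single pass and raises:

```python
import time


def try_model(model):
    """Stand-in for a completion call that always fails (hypothetical)."""
    raise RuntimeError(f"{model} unavailable")


def old_shape(fallbacks, budget=0.1):  # tiny budget so the demo terminates fast
    """Removed behavior: cycle the whole list until the time budget expires."""
    start, response = time.time(), None
    while response is None and time.time() - start < budget:
        for model in fallbacks:
            try:
                response = try_model(model)
                return response
            except Exception:
                continue  # cool-down bookkeeping elided
    return response  # silently None when everything failed


def new_shape(fallbacks):
    """New behavior: one pass over the list, then a hard failure."""
    for model in fallbacks:
        try:
            return try_model(model)
        except Exception:
            continue
    raise Exception("All fallback attempts failed.")


print(old_shape(["a", "b"]))  # None, only after spinning through the budget
try:
    new_shape(["a", "b"])
except Exception as e:
    print(e)  # raised immediately after a single pass
```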
tests/local_testing/test_acompletion_fallbacks.py (new file, 103 lines)
@@ -0,0 +1,103 @@
+import asyncio
+import os
+import sys
+import time
+import traceback
+
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import concurrent
+
+from dotenv import load_dotenv
+
+import asyncio
+import litellm
+
+
+@pytest.mark.asyncio
+async def test_acompletion_fallbacks_basic():
+    response = await litellm.acompletion(
+        model="openai/unknown-model",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        fallbacks=["openai/gpt-4o-mini"],
+    )
+    print(response)
+    assert response is not None
+
+
+@pytest.mark.asyncio
+async def test_acompletion_fallbacks_bad_models():
+    """
+    Test that the acompletion call fails within the 5-second timeout if no fallbacks work
+    """
+    try:
+        # Wrap the acompletion call with asyncio.wait_for to enforce a timeout
+        response = await asyncio.wait_for(
+            litellm.acompletion(
+                model="openai/unknown-model",
+                messages=[{"role": "user", "content": "Hello, world!"}],
+                fallbacks=["openai/bad-model", "openai/unknown-model"],
+            ),
+            timeout=5.0,  # Timeout after 5 seconds
+        )
+        assert response is not None
+    except asyncio.TimeoutError:
+        pytest.fail("Test timed out - possible infinite loop in fallbacks")
+    except Exception as e:
+        print(e)
+        pass
+
+
+@pytest.mark.asyncio
+async def test_acompletion_fallbacks_with_dict_config():
+    """
+    Test fallbacks with dictionary configuration that includes model-specific settings
+    """
+    response = await litellm.acompletion(
+        model="openai/gpt-4o-mini",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        api_key="very-bad-api-key",
+        fallbacks=[{"api_key": os.getenv("OPENAI_API_KEY")}],
+    )
+    assert response is not None
+
+
+@pytest.mark.asyncio
+async def test_acompletion_fallbacks_empty_list():
+    """
+    Test behavior when fallbacks list is empty
+    """
+    try:
+        response = await litellm.acompletion(
+            model="openai/unknown-model",
+            messages=[{"role": "user", "content": "Hello, world!"}],
+            fallbacks=[],
+        )
+    except Exception as e:
+        assert isinstance(e, litellm.NotFoundError)
+
+
+@pytest.mark.asyncio
+async def test_acompletion_fallbacks_none_response():
+    """
+    Test handling when a fallback model returns None
+    Should continue to next fallback rather than returning None
+    """
+    response = await litellm.acompletion(
+        model="openai/unknown-model",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        fallbacks=["gpt-3.5-turbo"],  # replace with a model you know works
+    )
+    assert response is not None
+
+
+async def test_completion_fallbacks_sync():
+    response = litellm.completion(
+        model="openai/unknown-model",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        fallbacks=["openai/gpt-4o-mini"],
+    )
+    print(response)
+    assert response is not None