(core sdk fix) - fix fallbacks stuck in infinite loop (#7751)

* test_acompletion_fallbacks_basic

* use common run_async_function

* fix completion_with_fallbacks

* fix completion with fallbacks

* fix fallback utils

* test_acompletion_fallbacks_basic

* test_completion_fallbacks_sync

* huggingface/mistralai/Mistral-7B-Instruct-v0.3
This commit is contained in:
Ishaan Jaff 2025-01-13 19:34:34 -08:00 committed by GitHub
parent 970e9c7507
commit f1335362cf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 222 additions and 156 deletions

View file

@ -46,6 +46,7 @@ from litellm import get_secret_str
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.asyncify import run_async_function
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
@ -3264,32 +3265,7 @@ class Router:
Wrapped to reduce code duplication and prevent bugs.
"""
from concurrent.futures import ThreadPoolExecutor
def run_in_new_loop():
"""Run the coroutine in a new event loop within this thread."""
new_loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(new_loop)
return new_loop.run_until_complete(
self.async_function_with_fallbacks(*args, **kwargs)
)
finally:
new_loop.close()
asyncio.set_event_loop(None)
try:
# First, try to get the current event loop
_ = asyncio.get_running_loop()
# If we're already in an event loop, run in a separate thread
# to avoid nested event loop issues
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
return future.result()
except RuntimeError:
# No running event loop, we can safely run in this thread
return run_in_new_loop()
return run_async_function(self.async_function_with_fallbacks, *args, **kwargs)
def _get_fallback_model_group_from_fallbacks(
self,