Control fallback prompts client-side (#7334)

* feat(router.py): support passing model-specific messages in fallbacks (see the usage sketch before the diff below)

* docs(routing.md): move router timeouts into a separate doc

allow for 1 fallbacks doc (across proxy/router)

* docs(routing.md): cleanup router docs

* docs(reliability.md): cleanup docs

* docs(reliability.md): cleaned up fallback doc

just have 1 doc across sdk/proxy

simplifies docs

* docs(reliability.md): add setting model-specific fallback prompts

* fix: fix linting errors

* test: skip test causing openai rate limit errors

* test: fix test

* test: run vertex test first to catch error
Krish Dholakia 2024-12-20 19:09:53 -08:00 committed by GitHub
parent 4322954dc6
commit 70a9ea99f2
12 changed files with 861 additions and 553 deletions
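
Before the diff, a minimal usage sketch of the two client-side fallback formats this commit adds. The two shapes of the per-request fallbacks argument follow the examples in the diff's new inline comment; everything else (model list entries, model ids) is illustrative, not taken from the commit.

from litellm import Router

# Illustrative deployments; swap in your own model list and credentials.
router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {
            "model_name": "claude-3-haiku",
            "litellm_params": {"model": "anthropic/claude-3-haiku-20240307"},
        },
    ]
)

# Format 1: plain client-side fallbacks, a list of model group names to retry in order.
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    fallbacks=["claude-3-haiku"],
)

# Format 2: model-specific fallback prompts. Each entry carries its own
# messages, so the fallback model receives a different prompt than the original.
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    fallbacks=[
        {
            "model": "claude-3-haiku",
            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
        }
    ],
)

The second format is what "setting model-specific fallback prompts" in the docs bullet refers to: the fallback deployment gets its own messages instead of the original prompt.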

litellm/router.py

@@ -69,6 +69,7 @@ from litellm.router_utils.cooldown_handlers import (
     _set_cooldown_deployments,
 )
 from litellm.router_utils.fallback_event_handlers import (
+    _check_non_standard_fallback_format,
     get_fallback_model_group,
     run_async_fallback,
 )
@@ -2647,6 +2648,27 @@ class Router:
         try:
             verbose_router_logger.info("Trying to fallback b/w models")
+            # check if client-side fallbacks are used (e.g. fallbacks = ["gpt-3.5-turbo", "claude-3-haiku"] or fallbacks=[{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}])
+            is_non_standard_fallback_format = _check_non_standard_fallback_format(
+                fallbacks=fallbacks
+            )
+            if is_non_standard_fallback_format:
+                input_kwargs.update(
+                    {
+                        "fallback_model_group": fallbacks,
+                        "original_model_group": original_model_group,
+                    }
+                )
+                response = await run_async_fallback(
+                    *args,
+                    **input_kwargs,
+                )
+                return response
             if isinstance(e, litellm.ContextWindowExceededError):
                 if context_window_fallbacks is not None:
                     fallback_model_group: Optional[List[str]] = (
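
_check_non_standard_fallback_format is imported in the first hunk, but its body is not part of this excerpt. A plausible sketch of the check, inferred from the inline comment and the call site above (an assumption, not the actual helper):

from typing import List, Optional, Union

def _check_non_standard_fallback_format(
    fallbacks: Optional[Union[List, dict]],
) -> bool:
    # Sketch (assumption): client-side fallbacks arrive as a list of model-name
    # strings or a list of {"model": ..., "messages": ...} dicts, rather than
    # the standard [{"<model_group>": ["<fallback_group>", ...]}] mapping.
    if fallbacks is None or not isinstance(fallbacks, list) or len(fallbacks) == 0:
        return False
    if all(isinstance(item, str) for item in fallbacks):
        return True
    return all(isinstance(item, dict) and "model" in item for item in fallbacks)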
@@ -2722,7 +2744,7 @@ class Router:
             verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
             fallback_model_group, generic_fallback_idx = (
                 get_fallback_model_group(
-                    fallbacks=fallbacks,
+                    fallbacks=fallbacks,  # if fallbacks = [{"gpt-3.5-turbo": ["claude-3-haiku"]}]
                     model_group=cast(str, model_group),
                 )
             )
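
For contrast, the standard router-level format that get_fallback_model_group resolves (per the new inline comment above) maps a model group to an ordered list of fallback groups and is configured once on the Router rather than per-request. A minimal sketch, with illustrative deployments:

from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {
            "model_name": "claude-3-haiku",
            "litellm_params": {"model": "anthropic/claude-3-haiku-20240307"},
        },
    ],
    # Router-level fallbacks: if gpt-3.5-turbo fails, retry on claude-3-haiku.
    fallbacks=[{"gpt-3.5-turbo": ["claude-3-haiku"]}],
)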