Control fallback prompts client-side (#7334)

* feat(router.py): support passing model-specific messages in fallbacks (see the sketch after this list)

* docs(routing.md): move router timeouts into a separate doc

allow for 1 fallbacks doc (across proxy/router)

* docs(routing.md): cleanup router docs

* docs(reliability.md): cleanup docs

* docs(reliability.md): cleaned up fallback doc

just have 1 doc across sdk/proxy

simplifies docs

* docs(reliability.md): add setting model-specific fallback prompts

* fix: fix linting errors

* test: skip test causing openai rate limit errors

* test: fix test

* test: run vertex test first to catch error
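A minimal sketch of what the first bullet enables, assuming the non-standard fallback format described in the new `_check_non_standard_fallback_format` docstring further down; the model names and router params here are illustrative, not taken from this commit:

```python
# Hypothetical router config (names illustrative): a dict-style fallback entry
# carries its own "messages", so the fallback call can swap the prompt, not
# just the model.
from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "claude-3-haiku", "litellm_params": {"model": "claude-3-haiku"}},
    ],
    fallbacks=[
        {
            "model": "claude-3-haiku",  # LiteLLMParamsTypedDict keys mark the non-standard format
            "messages": [{"role": "user", "content": "Hey, how's it going?"}],
        }
    ],
)
```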
Krish Dholakia authored 2024-12-20 19:09:53 -08:00; committed by GitHub
parent 4322954dc6
commit 70a9ea99f2
12 changed files with 861 additions and 553 deletions


@@ -1,9 +1,10 @@
 from enum import Enum
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import litellm
 from litellm._logging import verbose_router_logger
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.router import LiteLLMParamsTypedDict
 
 if TYPE_CHECKING:
     from litellm.router import Router as _Router
@@ -67,7 +68,7 @@ def get_fallback_model_group(
             elif list(item.keys())[0] == "*":  # check generic fallback
                 generic_fallback_idx = idx
         elif isinstance(item, str):
-            fallback_model_group = [fallbacks.pop(idx)]
+            fallback_model_group = [fallbacks.pop(idx)]  # returns single-item list
     ## if none, check for generic fallback
     if fallback_model_group is None:
         if stripped_model_fallback is not None:
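For orientation (inferred from the visible fragment of `get_fallback_model_group`, which this hunk only annotates), the loop distinguishes three entry shapes in the `fallbacks` list:

```python
# Shapes the loop above distinguishes (illustrative values):
fallbacks = [
    {"gpt-3.5-turbo": ["claude-3-haiku"]},  # standard: maps a model group to its fallbacks
    {"*": ["gpt-4o"]},                      # generic fallback, remembered by index for later
    "claude-3-haiku",                       # bare string: popped out as a single-item group
]
```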
@@ -122,9 +123,12 @@
             # LOGGING
             kwargs = litellm_router.log_retry(kwargs=kwargs, e=original_exception)
             verbose_router_logger.info(f"Falling back to model_group = {mg}")
-            kwargs["model"] = mg
+            if isinstance(mg, str):
+                kwargs["model"] = mg
+            elif isinstance(mg, dict):
+                kwargs.update(mg)
             kwargs.setdefault("metadata", {}).update(
-                {"model_group": mg}
+                {"model_group": kwargs.get("model", None)}
             )  # update model_group used, if fallbacks are done
             kwargs["fallback_depth"] = fallback_depth + 1
             kwargs["max_fallbacks"] = max_fallbacks
@@ -310,3 +314,31 @@ async def log_failure_fallback_event(
         verbose_router_logger.error(
             f"Error in log_failure_fallback_event: {str(e)}"
         )
+
+
+def _check_non_standard_fallback_format(fallbacks: Optional[List[Any]]) -> bool:
+    """
+    Checks if the fallbacks list is a list of strings or a list of dictionaries.
+
+    If
+    - List[str]: e.g. ["claude-3-haiku", "openai/o-1"]
+    - List[Dict[<LiteLLMParamsTypedDict>, Any]]: e.g. [{"model": "claude-3-haiku", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}]
+
+    If [{"gpt-3.5-turbo": ["claude-3-haiku"]}] then standard format.
+    """
+    if fallbacks is None or not isinstance(fallbacks, list) or len(fallbacks) == 0:
+        return False
+
+    if all(isinstance(item, str) for item in fallbacks):
+        return True
+    elif all(isinstance(item, dict) for item in fallbacks):
+        for key in LiteLLMParamsTypedDict.__annotations__.keys():
+            if key in fallbacks[0].keys():
+                return True
+    return False
+
+
+def run_non_standard_fallback_format(
+    fallbacks: Union[List[str], List[Dict[str, Any]]], model_group: str
+):
+    pass
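To make the docstring concrete, here is how the new checker behaves on the formats it names (a usage sketch against the function as committed above; the model names are illustrative):

```python
# True: plain list of model-name strings
_check_non_standard_fallback_format(["claude-3-haiku", "openai/o-1"])

# True: dicts carrying LiteLLMParamsTypedDict keys (e.g. "model")
_check_non_standard_fallback_format(
    [{"model": "claude-3-haiku", "messages": [{"role": "user", "content": "Hey"}]}]
)

# False: standard format, keyed by the model group being shadowed
_check_non_standard_fallback_format([{"gpt-3.5-turbo": ["claude-3-haiku"]}])

# False: empty or missing fallbacks
_check_non_standard_fallback_format(None)
```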