feat(router.py): enable pre-call checks

Filter out models in a model group whose context window is too small for a given message.

https://github.com/BerriAI/litellm/issues/872
Krrish Dholakia 2024-03-23 18:03:30 -07:00
parent 2fabff06c0
commit eb3ca85d7e
7 changed files with 3417 additions and 526 deletions
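A minimal usage sketch of the new flag, based on the Router signature in this diff. The model_list entries, model names, and prompt below are hypothetical placeholders, not part of the commit:

from litellm import Router

# hypothetical model group with two deployments of different context sizes
model_list = [
    {
        "model_name": "gpt-3.5-turbo",  # model group name
        "litellm_params": {"model": "gpt-3.5-turbo"},
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo-16k"},
    },
]

# with enable_pre_call_checks=True, deployments whose context window cannot fit
# the prompt are filtered out before routing
router = Router(model_list=model_list, enable_pre_call_checks=True)

response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "a very long prompt ..."}],
)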


@@ -98,6 +98,7 @@ class Router:
        fallbacks: List = [],
        context_window_fallbacks: List = [],
        model_group_alias: Optional[dict] = {},
        enable_pre_call_checks: bool = False,
        retry_after: int = 0, # min time to wait before retrying a failed request
        allowed_fails: Optional[
            int
@@ -131,6 +132,7 @@
            debug_level (Literal["DEBUG", "INFO"]): Debug level for logging. Defaults to "INFO".
            fallbacks (List): List of fallback options. Defaults to [].
            context_window_fallbacks (List): List of context window fallback options. Defaults to [].
            enable_pre_call_checks (bool): Filter out deployments which are outside the context window limits for a given prompt. Defaults to False.
            model_group_alias (Optional[dict]): Alias for model groups. Defaults to {}.
            retry_after (int): Minimum time to wait before retrying a failed request. Defaults to 0.
            allowed_fails (Optional[int]): Number of allowed fails before adding to cooldown. Defaults to None.
@@ -143,6 +145,7 @@
        """
        self.set_verbose = set_verbose
        self.debug_level = debug_level
        self.enable_pre_call_checks = enable_pre_call_checks
        if self.set_verbose == True:
            if debug_level == "INFO":
                verbose_router_logger.setLevel(logging.INFO)
@@ -2150,6 +2153,54 @@ class Router:
            client = self.cache.get_cache(key=cache_key)
            return client

    def _pre_call_checks(
        self,
        model: str,
        healthy_deployments: List,
        messages: List[Dict[str, str]],
    ):
        """
        Filter out model in model group, if:

        - model context window < message length
        - function call and model doesn't support function calling
        """
        verbose_router_logger.debug(
            f"Starting Pre-call checks for deployments in model={model}"
        )

        _returned_deployments = copy.deepcopy(healthy_deployments)

        invalid_model_indices = []

        try:
            input_tokens = litellm.token_counter(messages=messages)
        except Exception:
            return _returned_deployments

        for idx, deployment in enumerate(_returned_deployments):
            # see if we have the info for this model
            try:
                model_info = litellm.get_model_info(model=deployment["model_name"])
            except Exception:
                continue

            if (
                isinstance(model_info, dict)
                and model_info.get("max_input_tokens", None) is not None
            ):
                if (
                    isinstance(model_info["max_input_tokens"], int)
                    and input_tokens > model_info["max_input_tokens"]
                ):
                    invalid_model_indices.append(idx)

        if len(invalid_model_indices) > 0:
            for idx in reversed(invalid_model_indices):
                _returned_deployments.pop(idx)

        return _returned_deployments
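The context-window filter above can be sketched in isolation. This is not the committed code; it assumes litellm.get_model_info exposes max_input_tokens for the (hypothetical) candidate models:

import litellm

messages = [{"role": "user", "content": "a very long prompt ..."}]
input_tokens = litellm.token_counter(messages=messages)

# hypothetical candidate deployments for one model group
candidates = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"]

usable = []
for name in candidates:
    info = litellm.get_model_info(model=name)
    max_input = info.get("max_input_tokens")
    # keep a deployment unless its known input limit is smaller than the prompt
    if not isinstance(max_input, int) or input_tokens <= max_input:
        usable.append(name)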
    def get_available_deployment(
        self,
        model: str,
@@ -2209,6 +2260,12 @@
        for deployment in deployments_to_remove:
            healthy_deployments.remove(deployment)

        # filter pre-call checks
        if self.enable_pre_call_checks and messages is not None:
            healthy_deployments = self._pre_call_checks(
                model=model, healthy_deployments=healthy_deployments, messages=messages
            )

        verbose_router_logger.debug(
            f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}"
        )
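With the flag enabled, this filter runs before a deployment is picked. One way to observe it, reusing the hypothetical router from the earlier sketch (assuming get_available_deployment returns the chosen deployment dict, as the rest of this file suggests):

# inspect which deployment survives the pre-call check for a long prompt
deployment = router.get_available_deployment(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "a very long prompt ..."}],
)
print(deployment["litellm_params"]["model"])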