feat(router.py): enable pre-call checks

Filter out deployments in a model group whose context window is too small for a given message. https://github.com/BerriAI/litellm/issues/872

Parent: 2fabff06c0
Commit: eb3ca85d7e
7 changed files with 3417 additions and 526 deletions
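
For context, a minimal usage sketch of the new flag (the model list, deployment names, and prompt below are hypothetical; enable_pre_call_checks is the parameter this commit adds):

    from litellm import Router

    # Hypothetical model group with two deployments: a small-context and a
    # large-context variant of the same model.
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",  # model group name
            "litellm_params": {"model": "gpt-3.5-turbo"},
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo-16k"},
        },
    ]

    # With pre-call checks enabled, deployments whose max_input_tokens is
    # smaller than the prompt are filtered out before routing.
    router = Router(model_list=model_list, enable_pre_call_checks=True)

    response = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "a very long prompt ..."}],
    )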
@@ -98,6 +98,7 @@ class Router:
         fallbacks: List = [],
         context_window_fallbacks: List = [],
         model_group_alias: Optional[dict] = {},
+        enable_pre_call_checks: bool = False,
         retry_after: int = 0,  # min time to wait before retrying a failed request
         allowed_fails: Optional[
             int
@@ -131,6 +132,7 @@
             debug_level (Literal["DEBUG", "INFO"]): Debug level for logging. Defaults to "INFO".
             fallbacks (List): List of fallback options. Defaults to [].
             context_window_fallbacks (List): List of context window fallback options. Defaults to [].
+            enable_pre_call_checks (boolean): Filter out deployments which are outside context window limits for a given prompt
             model_group_alias (Optional[dict]): Alias for model groups. Defaults to {}.
             retry_after (int): Minimum time to wait before retrying a failed request. Defaults to 0.
             allowed_fails (Optional[int]): Number of allowed fails before adding to cooldown. Defaults to None.
@@ -143,6 +145,7 @@
         """
         self.set_verbose = set_verbose
         self.debug_level = debug_level
+        self.enable_pre_call_checks = enable_pre_call_checks
         if self.set_verbose == True:
             if debug_level == "INFO":
                 verbose_router_logger.setLevel(logging.INFO)
@@ -2150,6 +2153,54 @@
             client = self.cache.get_cache(key=cache_key)
             return client
 
+    def _pre_call_checks(
+        self,
+        model: str,
+        healthy_deployments: List,
+        messages: List[Dict[str, str]],
+    ):
+        """
+        Filter out a model in the model group if:
+
+        - the model's context window < the message length
+        - the call is a function call and the model doesn't support function calling
+        """
+        verbose_router_logger.debug(
+            f"Starting Pre-call checks for deployments in model={model}"
+        )
+
+        _returned_deployments = copy.deepcopy(healthy_deployments)
+
+        invalid_model_indices = []
+
+        try:
+            input_tokens = litellm.token_counter(messages=messages)
+        except:
+            # if token counting fails, skip filtering and keep all deployments
+            return _returned_deployments
+
+        for idx, deployment in enumerate(_returned_deployments):
+            # see if we have the info for this model
+            try:
+                model_info = litellm.get_model_info(model=deployment["model_name"])
+            except:
+                # no model info available -> keep the deployment
+                continue
+
+            if (
+                isinstance(model_info, dict)
+                and model_info.get("max_input_tokens", None) is not None
+            ):
+                if (
+                    isinstance(model_info["max_input_tokens"], int)
+                    and input_tokens > model_info["max_input_tokens"]
+                ):
+                    # prompt doesn't fit this deployment's context window
+                    invalid_model_indices.append(idx)
+
+        if len(invalid_model_indices) > 0:
+            # pop in reverse so earlier indices remain valid
+            for idx in reversed(invalid_model_indices):
+                _returned_deployments.pop(idx)
+
+        return _returned_deployments
+
     def get_available_deployment(
         self,
         model: str,
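
The two litellm helpers the check relies on can be exercised on their own; a sketch, assuming the model name is present in litellm's model-info map:

    import litellm

    messages = [{"role": "user", "content": "hello " * 5000}]  # a long prompt
    input_tokens = litellm.token_counter(messages=messages)

    # get_model_info returns metadata, including max_input_tokens, for known models
    model_info = litellm.get_model_info(model="gpt-3.5-turbo")
    if (
        model_info.get("max_input_tokens") is not None
        and input_tokens > model_info["max_input_tokens"]
    ):
        print("this deployment would be dropped by _pre_call_checks")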
@@ -2209,6 +2260,12 @@
         for deployment in deployments_to_remove:
             healthy_deployments.remove(deployment)
 
+        # filter pre-call checks
+        if self.enable_pre_call_checks and messages is not None:
+            healthy_deployments = self._pre_call_checks(
+                model=model, healthy_deployments=healthy_deployments, messages=messages
+            )
+
         verbose_router_logger.debug(
             f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}"
         )
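
End to end, the effect shows up in get_available_deployment; a hedged sketch reusing the hypothetical router from the first example:

    # With enable_pre_call_checks=True, a prompt that exceeds the small
    # deployment's context window should only route to the larger one.
    deployment = router.get_available_deployment(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello " * 5000}],
    )
    print(deployment["litellm_params"]["model"])  # expected: gpt-3.5-turbo-16k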