fix(router.py): use user-defined model_input_tokens for pre-call filter checks

This commit is contained in:
Krrish Dholakia 2024-06-24 17:25:26 -07:00
parent 123477b55a
commit f5fbdf0fee
3 changed files with 58 additions and 5 deletions

View file

@ -404,6 +404,7 @@ class Router:
litellm.failure_callback = [self.deployment_callback_on_failure]
print( # noqa
f"Intialized router with Routing strategy: {self.routing_strategy}\n\n"
f"Routing enable_pre_call_checks: {self.enable_pre_call_checks}\n\n"
f"Routing fallbacks: {self.fallbacks}\n\n"
f"Routing content fallbacks: {self.content_policy_fallbacks}\n\n"
f"Routing context window fallbacks: {self.context_window_fallbacks}\n\n"
@ -3915,9 +3916,38 @@ class Router:
raise Exception("Model invalid format - {}".format(type(model)))
return None
def get_router_model_info(self, deployment: dict) -> ModelMapInfo:
"""
For a given model id, return the model info (max tokens, input cost, output cost, etc.).
Augment litellm info with additional params set in `model_info`.
Returns
- ModelInfo - If found -> typed dict with max tokens, input cost, etc.
"""
## SET MODEL NAME
base_model = deployment.get("model_info", {}).get("base_model", None)
if base_model is None:
base_model = deployment.get("litellm_params", {}).get("base_model", None)
model = base_model or deployment.get("litellm_params", {}).get("model", None)
## GET LITELLM MODEL INFO
model_info = litellm.get_model_info(model=model)
## CHECK USER SET MODEL INFO
user_model_info = deployment.get("model_info", {})
model_info.update(user_model_info)
return model_info
def get_model_info(self, id: str) -> Optional[dict]:
"""
For a given model id, return the model info
Returns
- dict: the model in list with 'model_name', 'litellm_params', Optional['model_info']
- None: could not find deployment in list
"""
for model in self.model_list:
if "model_info" in model and "id" in model["model_info"]:
@ -4307,6 +4337,7 @@ class Router:
return _returned_deployments
_context_window_error = False
_potential_error_str = ""
_rate_limit_error = False
## get model group RPM ##
@ -4327,7 +4358,7 @@ class Router:
model = base_model or deployment.get("litellm_params", {}).get(
"model", None
)
model_info = litellm.get_model_info(model=model)
model_info = self.get_router_model_info(deployment=deployment)
if (
isinstance(model_info, dict)
@ -4339,6 +4370,11 @@ class Router:
):
invalid_model_indices.append(idx)
_context_window_error = True
_potential_error_str += (
"Model={}, Max Input Tokens={}, Got={}".format(
model, model_info["max_input_tokens"], input_tokens
)
)
continue
except Exception as e:
verbose_router_logger.debug("An error occurs - {}".format(str(e)))
@ -4440,7 +4476,9 @@ class Router:
)
elif _context_window_error == True:
raise litellm.ContextWindowExceededError(
message="Context Window exceeded for given call",
message="litellm._pre_call_checks: Context Window exceeded for given call. No models have context window large enough for this call.\n{}".format(
_potential_error_str
),
model=model,
llm_provider="",
response=httpx.Response(