fix(router.py): only return 'max_tokens', 'input_cost_per_token', etc. in 'get_router_model_info' if base_model is set

Krrish Dholakia 2024-06-26 16:02:23 -07:00
parent 66e3a4f30e
commit 6e53de5462
2 changed files with 137 additions and 6 deletions
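For context, `base_model` is the hint litellm uses to map an opaque Azure deployment name onto a known model for pricing and token-limit lookups (see the cost-tracking docs linked in the diff below). A minimal sketch of a deployment entry that sets it; the alias, deployment name, and endpoint are illustrative placeholders, not values from this commit:

import os

# Hypothetical deployment entry; the shape follows litellm's documented config.
deployment = {
    "model_name": "gpt-4",  # public alias callers route against
    "litellm_params": {
        "model": "azure/my-gpt4-deployment",  # opaque azure deployment name
        "api_base": "https://example-resource.openai.azure.com/",
        "api_key": os.getenv("AZURE_API_KEY"),
    },
    "model_info": {
        "base_model": "azure/gpt-4",  # lets litellm resolve max_tokens and costs
    },
}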


@@ -105,7 +105,9 @@ class Router:
     def __init__(
         self,
-        model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
+        model_list: Optional[
+            Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
+        ] = None,
         ## ASSISTANTS API ##
         assistants_config: Optional[AssistantsTypedDict] = None,
         ## CACHING ##
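The widened `model_list` annotation means a plain list of dicts now type-checks alongside the `DeploymentTypedDict` form. A minimal usage sketch, reusing the hypothetical `deployment` dict above:

from litellm import Router

# Either List[DeploymentTypedDict] or List[Dict[str, Any]] satisfies the annotation.
router = Router(model_list=[deployment])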
@@ -3970,16 +3972,36 @@ class Router:
         Augment litellm info with additional params set in `model_info`.
+        For azure models, ignore the `model:`. Only set max tokens, cost values if base_model is set.
         Returns
         - ModelInfo - If found -> typed dict with max tokens, input cost, etc.
+        Raises:
+        - ValueError -> If model is not mapped yet
         """
-        ## SET MODEL NAME
+        ## GET BASE MODEL
         base_model = deployment.get("model_info", {}).get("base_model", None)
         if base_model is None:
             base_model = deployment.get("litellm_params", {}).get("base_model", None)
-        model = base_model or deployment.get("litellm_params", {}).get("model", None)
-        ## GET LITELLM MODEL INFO
+        model = base_model
+        ## GET PROVIDER
+        _model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+            model=deployment.get("litellm_params", {}).get("model", ""),
+            litellm_params=LiteLLM_Params(**deployment.get("litellm_params", {})),
+        )
+        ## SET MODEL TO 'model=' - if base_model is None + not azure
+        if custom_llm_provider == "azure" and base_model is None:
+            verbose_router_logger.error(
+                "Could not identify azure model. Set azure 'base_model' for accurate max tokens, cost tracking, etc.- https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models"
+            )
+        else:
+            model = deployment.get("litellm_params", {}).get("model", None)
+        ## GET LITELLM MODEL INFO - raises exception, if model is not mapped
         model_info = litellm.get_model_info(model=model)
         ## CHECK USER SET MODEL INFO
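The net effect of the new branch: Azure model metadata is only resolved when `base_model` is known; otherwise the router logs the error above instead of guessing limits and costs from the opaque deployment name. A hedged sketch of both cases (the method name comes from this commit, the surrounding values are the illustrative ones above):

# With base_model set, the returned info is populated from litellm's model map.
info = router.get_router_model_info(deployment=deployment)
print(info["max_tokens"], info["input_cost_per_token"])

# Without base_model on an azure deployment, the router now emits
# "Could not identify azure model..." rather than silently mis-pricing.
deployment_no_base = {**deployment, "model_info": {}}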
@@ -4365,7 +4387,7 @@ class Router:
"""
Filter out model in model group, if:
- model context window < message length
- model context window < message length. For azure openai models, requires 'base_model' is set. - https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models
- filter models above rpm limits
- if region given, filter out models not in that region / unknown region
- [TODO] function call and model doesn't support function calling
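The first bullet's context-window check reduces to comparing counted prompt tokens against the model's input limit. A standalone sketch using the public helpers the diff itself calls; the model string is illustrative and, per the note above, presumes `base_model` is resolvable:

import litellm

messages = [{"role": "user", "content": "a very long prompt ..."}]
input_tokens = litellm.token_counter(messages=messages)
model_info = litellm.get_model_info(model="azure/gpt-4")

# Deployments whose input limit is below the prompt size get filtered out.
max_input = model_info.get("max_input_tokens") or model_info.get("max_tokens")
if max_input is not None and input_tokens > max_input:
    pass  # this deployment would be dropped from the candidate list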
@@ -4382,6 +4404,11 @@ class Router:
         try:
             input_tokens = litellm.token_counter(messages=messages)
         except Exception as e:
+            verbose_router_logger.error(
+                "litellm.router.py::_pre_call_checks: failed to count tokens. Returning initial list of deployments. Got - {}".format(
+                    str(e)
+                )
+            )
             return _returned_deployments
         _context_window_error = False
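A note on the design choice in this hunk: when token counting throws, the check fails open. The router logs at error level and returns the unfiltered deployment list, so a counting bug degrades filtering accuracy rather than failing the request outright.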
@@ -4425,7 +4452,7 @@ class Router:
                     )
                     continue
             except Exception as e:
-                verbose_router_logger.debug("An error occurs - {}".format(str(e)))
+                verbose_router_logger.error("An error occurs - {}".format(str(e)))
             _litellm_params = deployment.get("litellm_params", {})
             model_id = deployment.get("model_info", {}).get("id", "")
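The final hunk only changes log severity: exceptions raised while checking an individual deployment were previously logged at debug and are now logged at error, so filtering failures surface at default verbosity alongside the new token-counter message.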