forked from phoenix/litellm-mirror
fix(router.py): only return 'max_tokens', 'input_cost_per_token', etc. in 'get_router_model_info' if base_model is set
parent a7122f91a1
commit aa6f7665c4
2 changed files with 137 additions and 6 deletions
litellm/router.py

@@ -105,7 +105,9 @@ class Router:
     def __init__(
         self,
-        model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
+        model_list: Optional[
+            Union[List[DeploymentTypedDict], List[dict[str, Any]], List[Dict[str, Any]]]
+        ] = None,
         ## ASSISTANTS API ##
         assistants_config: Optional[AssistantsTypedDict] = None,
         ## CACHING ##
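The widened annotation means Router now explicitly accepts a plain list of dicts as well as typed DeploymentTypedDict entries. A minimal sketch of what the broader signature permits (model name and params are illustrative, not from this commit):

from litellm import Router

# Each deployment is an ordinary dict; the Union above admits this shape
# alongside DeploymentTypedDict. Values are placeholders.
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ]
)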
@@ -3970,16 +3972,36 @@ class Router:
         Augment litellm info with additional params set in `model_info`.
 
+        For azure models, ignore the `model:`. Only set max tokens, cost values if base_model is set.
+
         Returns
         - ModelInfo - If found -> typed dict with max tokens, input cost, etc.
 
         Raises:
         - ValueError -> If model is not mapped yet
         """
-        ## SET MODEL NAME
+        ## GET BASE MODEL
         base_model = deployment.get("model_info", {}).get("base_model", None)
         if base_model is None:
             base_model = deployment.get("litellm_params", {}).get("base_model", None)
-        model = base_model or deployment.get("litellm_params", {}).get("model", None)
 
-        ## GET LITELLM MODEL INFO
+        model = base_model
+
+        ## GET PROVIDER
+        _model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+            model=deployment.get("litellm_params", {}).get("model", ""),
+            litellm_params=LiteLLM_Params(**deployment.get("litellm_params", {})),
+        )
+
+        ## SET MODEL TO 'model=' - if base_model is None + not azure
+        if custom_llm_provider == "azure" and base_model is None:
+            verbose_router_logger.error(
+                "Could not identify azure model. Set azure 'base_model' for accurate max tokens, cost tracking, etc.- https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models"
+            )
+        else:
+            model = deployment.get("litellm_params", {}).get("model", None)
+
+        ## GET LITELLM MODEL INFO - raises exception, if model is not mapped
         model_info = litellm.get_model_info(model=model)
 
         ## CHECK USER SET MODEL INFO
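The effect of the new branch: for azure deployments the router only reports max tokens and per-token costs when 'base_model' is set, since an azure deployment name alone does not identify the underlying model. A hedged sketch of a deployment that passes the check (deployment name, credentials, and the exact base_model value are illustrative):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "azure/my-gpt-4-deployment",  # arbitrary deployment name
                "api_key": "sk-...",                   # placeholder
                "api_base": "https://example-endpoint.openai.azure.com",  # placeholder
            },
            # base_model maps the deployment onto a model litellm has cost /
            # max_tokens data for; without it, the error above is logged.
            "model_info": {"base_model": "azure/gpt-4-1106-preview"},
        }
    ]
)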
@@ -4365,7 +4387,7 @@ class Router:
         """
         Filter out model in model group, if:
 
-        - model context window < message length
+        - model context window < message length. For azure openai models, requires 'base_model' is set. - https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models
         - filter models above rpm limits
         - if region given, filter out models not in that region / unknown region
         - [TODO] function call and model doesn't support function calling
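An illustrative sketch (not the library's implementation) of the context-window rule in the first bullet, assuming per-deployment metadata shaped like litellm.get_model_info output with a max_input_tokens field:

from typing import Optional

def fits_context_window(max_input_tokens: Optional[int], input_tokens: int) -> bool:
    # Unknown window (e.g. an azure deployment without base_model) keeps the
    # deployment in the candidate list rather than filtering on missing data.
    if max_input_tokens is None:
        return True
    return input_tokens <= max_input_tokens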
@@ -4382,6 +4404,11 @@ class Router:
         try:
             input_tokens = litellm.token_counter(messages=messages)
         except Exception as e:
+            verbose_router_logger.error(
+                "litellm.router.py::_pre_call_checks: failed to count tokens. Returning initial list of deployments. Got - {}".format(
+                    str(e)
+                )
+            )
             return _returned_deployments
 
         _context_window_error = False
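The added error log makes the fallback visible: if token counting raises, the router now reports the cause before returning the unfiltered deployment list. For reference, a token_counter call looks roughly like this (model choice illustrative):

import litellm

# Counts prompt tokens for a chat payload; this is the call that, on failure,
# triggers the except branch above.
tokens = litellm.token_counter(
    model="gpt-3.5-turbo",  # illustrative
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(tokens)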
@@ -4425,7 +4452,7 @@ class Router:
                 )
                 continue
             except Exception as e:
-                verbose_router_logger.debug("An error occurs - {}".format(str(e)))
+                verbose_router_logger.error("An error occurs - {}".format(str(e)))
 
         _litellm_params = deployment.get("litellm_params", {})
         model_id = deployment.get("model_info", {}).get("id", "")
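Raising the log level from debug to error makes pre-call-check failures visible under default logging configuration. A small sketch, assuming verbose_router_logger is a standard logging.Logger (it is elsewhere in litellm):

import logging

logging.basicConfig(level=logging.WARNING)  # typical default threshold
log = logging.getLogger("router-demo")      # hypothetical stand-in logger

log.debug("An error occurs - %s", "...")  # suppressed at WARNING threshold
log.error("An error occurs - %s", "...")  # emitted: ERROR >= WARNING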