fix(router.py): only return 'max_tokens', 'input_cost_per_token', etc. in 'get_router_model_info' if base_model is set

Krrish Dholakia 2024-06-26 16:02:23 -07:00
parent a7122f91a1
commit aa6f7665c4
2 changed files with 137 additions and 6 deletions
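
For context, the change means an azure deployment only gets accurate 'max_tokens' / cost metadata when 'base_model' is pinned in its 'model_info'. A minimal sketch of the intended usage, mirroring the test added in this commit (the credentials and the base_model value are placeholders):

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "azure/gpt-4",  # raw azure deployment name
                    "api_key": "my-fake-key",
                    "api_base": "my-fake-base",
                },
                # base_model maps the deployment onto a model litellm has token/cost data for
                "model_info": {"base_model": "azure/gpt-4-0125-preview", "id": "1"},
            }
        ]
    )

    deployment = router.get_deployment(model_id="1")
    info = router.get_router_model_info(deployment=deployment.to_json())
    print(info["max_tokens"], info["input_cost_per_token"])

Without the base_model entry, the router now logs an error for azure deployments, and get_router_model_info raises because the raw deployment name is not in litellm's model map.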

router.py

@@ -105,7 +105,9 @@ class Router:
     def __init__(
         self,
-        model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
+        model_list: Optional[
+            Union[List[DeploymentTypedDict], List[dict[str, Any]], List[Dict[str, Any]]]
+        ] = None,
         ## ASSISTANTS API ##
         assistants_config: Optional[AssistantsTypedDict] = None,
         ## CACHING ##
@@ -3970,16 +3972,36 @@ class Router:
         Augment litellm info with additional params set in `model_info`.
+
+        For azure models, ignore the `model:`. Only set max tokens, cost values if base_model is set.
 
         Returns
         - ModelInfo - If found -> typed dict with max tokens, input cost, etc.
+
+        Raises:
+        - ValueError -> If model is not mapped yet
         """
-        ## SET MODEL NAME
+        ## GET BASE MODEL
         base_model = deployment.get("model_info", {}).get("base_model", None)
         if base_model is None:
             base_model = deployment.get("litellm_params", {}).get("base_model", None)
-        model = base_model or deployment.get("litellm_params", {}).get("model", None)
-        ## GET LITELLM MODEL INFO
+        model = base_model
+
+        ## GET PROVIDER
+        _model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+            model=deployment.get("litellm_params", {}).get("model", ""),
+            litellm_params=LiteLLM_Params(**deployment.get("litellm_params", {})),
+        )
+
+        ## SET MODEL TO 'model=' - if base_model is None + not azure
+        if custom_llm_provider == "azure" and base_model is None:
+            verbose_router_logger.error(
+                "Could not identify azure model. Set azure 'base_model' for accurate max tokens, cost tracking, etc.- https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models"
+            )
+        else:
+            model = deployment.get("litellm_params", {}).get("model", None)
+
+        ## GET LITELLM MODEL INFO - raises exception, if model is not mapped
         model_info = litellm.get_model_info(model=model)
 
         ## CHECK USER SET MODEL INFO
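
litellm.get_model_info is the call that raises here when the resolved name has no entry in litellm's model map, which is why the docstring above now documents the ValueError. A small illustration (the deployment name is hypothetical):

    import litellm

    try:
        # A raw azure deployment name with no entry in litellm's model map:
        # there is no max_tokens / cost data to return, so this raises.
        litellm.get_model_info(model="azure/my-gpt4-deployment")
    except Exception as e:
        print(e)  # e.g. "... model isn't mapped yet ..."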
@@ -4365,7 +4387,7 @@ class Router:
         """
         Filter out model in model group, if:
-        - model context window < message length
+        - model context window < message length. For azure openai models, requires 'base_model' is set. - https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models
         - filter models above rpm limits
         - if region given, filter out models not in that region / unknown region
         - [TODO] function call and model doesn't support function calling
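
The checks listed in this docstring are opt-in on the Router constructor; a minimal sketch of enabling them (placeholder credentials, essentially what the second test added below exercises):

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "azure/gpt-4",
                    "api_key": "my-fake-key",
                    "api_base": "my-fake-base",
                },
                # base_model gives the check a known context window; if it were omitted,
                # the azure deployment would be skipped by this filter and kept in the list
                "model_info": {"base_model": "azure/gpt-4-0125-preview", "id": "1"},
            }
        ],
        enable_pre_call_checks=True,
    )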
@@ -4382,6 +4404,11 @@ class Router:
         try:
             input_tokens = litellm.token_counter(messages=messages)
         except Exception as e:
+            verbose_router_logger.error(
+                "litellm.router.py::_pre_call_checks: failed to count tokens. Returning initial list of deployments. Got - {}".format(
+                    str(e)
+                )
+            )
             return _returned_deployments
 
         _context_window_error = False
@@ -4425,7 +4452,7 @@ class Router:
                    )
                    continue
            except Exception as e:
-                verbose_router_logger.debug("An error occurs - {}".format(str(e)))
+                verbose_router_logger.error("An error occurs - {}".format(str(e)))
 
             _litellm_params = deployment.get("litellm_params", {})
             model_id = deployment.get("model_info", {}).get("id", "")

(router test file)

@@ -16,6 +16,7 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import httpx
 from dotenv import load_dotenv
@@ -1884,3 +1885,106 @@ async def test_router_model_usage(mock_response):
     else:
         print(f"allowed_fails: {allowed_fails}")
         raise e
+
+
+@pytest.mark.parametrize(
+    "model, base_model, llm_provider",
+    [
+        ("azure/gpt-4", None, "azure"),
+        ("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"),
+        ("gpt-4", None, "openai"),
+    ],
+)
+def test_router_get_model_info(model, base_model, llm_provider):
+    """
+    Test if router get model info works based on provider
+    For azure -> only if base model set
+    For openai -> use model=
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": model,
+                    "api_key": "my-fake-key",
+                    "api_base": "my-fake-base",
+                },
+                "model_info": {"base_model": base_model, "id": "1"},
+            }
+        ]
+    )
+
+    deployment = router.get_deployment(model_id="1")
+
+    assert deployment is not None
+
+    if llm_provider == "openai" or (base_model is not None and llm_provider == "azure"):
+        router.get_router_model_info(deployment=deployment.to_json())
+    else:
+        try:
+            router.get_router_model_info(deployment=deployment.to_json())
+            pytest.fail("Expected this to raise model not mapped error")
+        except Exception as e:
+            if "This model isn't mapped yet" in str(e):
+                pass
+
+
+@pytest.mark.parametrize(
+    "model, base_model, llm_provider",
+    [
+        ("azure/gpt-4", None, "azure"),
+        ("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"),
+        ("gpt-4", None, "openai"),
+    ],
+)
+def test_router_context_window_pre_call_check(model, base_model, llm_provider):
+    """
+    - For an azure model
+    - if no base model set
+    - don't enforce context window limits
+    """
+    try:
+        model_list = [
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": model,
+                    "api_key": "my-fake-key",
+                    "api_base": "my-fake-base",
+                },
+                "model_info": {"base_model": base_model, "id": "1"},
+            }
+        ]
+        router = Router(
+            model_list=model_list,
+            set_verbose=True,
+            enable_pre_call_checks=True,
+            num_retries=0,
+        )
+
+        litellm.token_counter = MagicMock()
+
+        def token_counter_side_effect(*args, **kwargs):
+            # Process args and kwargs if needed
+            return 1000000
+
+        litellm.token_counter.side_effect = token_counter_side_effect
+
+        try:
+            updated_list = router._pre_call_checks(
+                model="gpt-4",
+                healthy_deployments=model_list,
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            )
+            if llm_provider == "azure" and base_model is None:
+                assert len(updated_list) == 1
+            else:
+                pytest.fail("Expected to raise an error. Got={}".format(updated_list))
+        except Exception as e:
+            if (
+                llm_provider == "azure" and base_model is not None
+            ) or llm_provider == "openai":
+                pass
+    except Exception as e:
+        pytest.fail(f"Got unexpected exception on router! - {str(e)}")