forked from phoenix/litellm-mirror
fix(router.py): only return 'max_tokens', 'input_cost_per_token', etc. in 'get_router_model_info' if base_model is set
parent a7122f91a1
commit aa6f7665c4
2 changed files with 137 additions and 6 deletions
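For context, this is roughly the behavior the change enforces; a minimal sketch mirroring the new test added below (the api key/base are fake placeholders, and "azure/gpt-4-0125-preview" is just one example of a base_model present in litellm's model cost map):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "azure/gpt-4",
                "api_key": "my-fake-key",
                "api_base": "my-fake-base",
            },
            # without "base_model", get_router_model_info now raises a
            # "model isn't mapped" error instead of returning unreliable values
            "model_info": {"base_model": "azure/gpt-4-0125-preview", "id": "1"},
        }
    ]
)

deployment = router.get_deployment(model_id="1")
info = router.get_router_model_info(deployment=deployment.to_json())
print(info["max_tokens"], info["input_cost_per_token"])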
@@ -105,7 +105,9 @@ class Router:

    def __init__(
        self,
        model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
        model_list: Optional[
            Union[List[DeploymentTypedDict], List[dict[str, Any]], List[Dict[str, Any]]]
        ] = None,
        ## ASSISTANTS API ##
        assistants_config: Optional[AssistantsTypedDict] = None,
        ## CACHING ##
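The widened annotation spells out that plain dict deployment entries type-check alongside DeploymentTypedDict ones; a minimal sketch with a fake key:

from litellm import Router

router = Router(
    model_list=[
        {  # plain dict deployment entry
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "my-fake-key"},
        }
    ]
)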
@@ -3970,16 +3972,36 @@ class Router:

        Augment litellm info with additional params set in `model_info`.

        For azure models, ignore the `model:` value. Only set max tokens and cost values if `base_model` is set.

        Returns:
        - ModelInfo - If found -> typed dict with max tokens, input cost, etc.

        Raises:
        - ValueError -> If model is not mapped yet
        """
        ## SET MODEL NAME
        ## GET BASE MODEL
        base_model = deployment.get("model_info", {}).get("base_model", None)
        if base_model is None:
            base_model = deployment.get("litellm_params", {}).get("base_model", None)
        model = base_model or deployment.get("litellm_params", {}).get("model", None)

        ## GET LITELLM MODEL INFO
        model = base_model

        ## GET PROVIDER
        _model, custom_llm_provider, _, _ = litellm.get_llm_provider(
            model=deployment.get("litellm_params", {}).get("model", ""),
            litellm_params=LiteLLM_Params(**deployment.get("litellm_params", {})),
        )

        ## SET MODEL TO 'model=' - if base_model is None + not azure
        if custom_llm_provider == "azure" and base_model is None:
            verbose_router_logger.error(
                "Could not identify azure model. Set azure 'base_model' for accurate max tokens, cost tracking, etc.- https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models"
            )
        else:
            model = deployment.get("litellm_params", {}).get("model", None)

        ## GET LITELLM MODEL INFO - raises exception, if model is not mapped
        model_info = litellm.get_model_info(model=model)

        ## CHECK USER SET MODEL INFO
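The helpers used here can also be called directly; a rough sketch ("azure/gpt-4-0125-preview" is again just one name known to be in litellm's model cost map):

import litellm

# split "provider/model" into its parts - for an azure deployment this yields
# custom_llm_provider == "azure"
_model, custom_llm_provider, _, _ = litellm.get_llm_provider(model="azure/gpt-4")

# look the model up in the cost map - raises if the name isn't mapped, which is
# why an azure deployment name alone can't be trusted for max tokens / cost
info = litellm.get_model_info(model="azure/gpt-4-0125-preview")
print(custom_llm_provider, info.get("max_tokens"), info.get("input_cost_per_token"))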
@@ -4365,7 +4387,7 @@ class Router:
        """
        Filter out model in model group, if:

        - model context window < message length
        - model context window < message length. For azure openai models, this requires 'base_model' to be set. - https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models
        - filter models above rpm limits
        - if region given, filter out models not in that region / unknown region
        - [TODO] function call and model doesn't support function calling
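These checks only run when pre-call checks are enabled on the Router; a minimal sketch (fake credentials, with base_model set so the azure deployment's context window is actually known):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "azure/gpt-4",
                "api_key": "my-fake-key",
                "api_base": "my-fake-base",
            },
            "model_info": {"base_model": "azure/gpt-4-0125-preview", "id": "1"},
        }
    ],
    enable_pre_call_checks=True,  # opt in to the filtering described above
)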
@@ -4382,6 +4404,11 @@ class Router:
        try:
            input_tokens = litellm.token_counter(messages=messages)
        except Exception as e:
            verbose_router_logger.error(
                "litellm.router.py::_pre_call_checks: failed to count tokens. Returning initial list of deployments. Got - {}".format(
                    str(e)
                )
            )
            return _returned_deployments

        _context_window_error = False
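The token count compared against each model's context window comes from litellm.token_counter; when no model is passed (as in the code above) it appears to fall back to a default tokenizer, so the count is approximate. A quick sketch:

import litellm

messages = [{"role": "user", "content": "Hey, how's it going?"}]
input_tokens = litellm.token_counter(messages=messages)
print(input_tokens)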
@@ -4425,7 +4452,7 @@ class Router:
                    )
                    continue
            except Exception as e:
                verbose_router_logger.debug("An error occurs - {}".format(str(e)))
                verbose_router_logger.error("An error occurs - {}".format(str(e)))

            _litellm_params = deployment.get("litellm_params", {})
            model_id = deployment.get("model_info", {}).get("id", "")
@@ -16,6 +16,7 @@ sys.path.insert(
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
from dotenv import load_dotenv
@@ -1884,3 +1885,106 @@ async def test_router_model_usage(mock_response):
        else:
            print(f"allowed_fails: {allowed_fails}")
            raise e

@pytest.mark.parametrize(
    "model, base_model, llm_provider",
    [
        ("azure/gpt-4", None, "azure"),
        ("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"),
        ("gpt-4", None, "openai"),
    ],
)
def test_router_get_model_info(model, base_model, llm_provider):
    """
    Test if router get model info works based on provider

    For azure -> only if base model set
    For openai -> use model=
    """
    router = Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": model,
                    "api_key": "my-fake-key",
                    "api_base": "my-fake-base",
                },
                "model_info": {"base_model": base_model, "id": "1"},
            }
        ]
    )

    deployment = router.get_deployment(model_id="1")

    assert deployment is not None

    if llm_provider == "openai" or (base_model is not None and llm_provider == "azure"):
        router.get_router_model_info(deployment=deployment.to_json())
    else:
        try:
            router.get_router_model_info(deployment=deployment.to_json())
            pytest.fail("Expected this to raise model not mapped error")
        except Exception as e:
            if "This model isn't mapped yet" in str(e):
                pass

@pytest.mark.parametrize(
    "model, base_model, llm_provider",
    [
        ("azure/gpt-4", None, "azure"),
        ("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"),
        ("gpt-4", None, "openai"),
    ],
)
def test_router_context_window_pre_call_check(model, base_model, llm_provider):
    """
    - For an azure model
    - if no base model set
    - don't enforce context window limits
    """
    try:
        model_list = [
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": model,
                    "api_key": "my-fake-key",
                    "api_base": "my-fake-base",
                },
                "model_info": {"base_model": base_model, "id": "1"},
            }
        ]
        router = Router(
            model_list=model_list,
            set_verbose=True,
            enable_pre_call_checks=True,
            num_retries=0,
        )

        litellm.token_counter = MagicMock()

        def token_counter_side_effect(*args, **kwargs):
            # Process args and kwargs if needed
            return 1000000

        litellm.token_counter.side_effect = token_counter_side_effect
        try:
            updated_list = router._pre_call_checks(
                model="gpt-4",
                healthy_deployments=model_list,
                messages=[{"role": "user", "content": "Hey, how's it going?"}],
            )
            if llm_provider == "azure" and base_model is None:
                assert len(updated_list) == 1
            else:
                pytest.fail("Expected to raise an error. Got={}".format(updated_list))
        except Exception as e:
            if (
                llm_provider == "azure" and base_model is not None
            ) or llm_provider == "openai":
                pass
    except Exception as e:
        pytest.fail(f"Got unexpected exception on router! - {str(e)}")