fix(router.py): only return 'max_tokens', 'input_cost_per_token', etc. in 'get_router_model_info' if base_model is set
commit aa6f7665c4 (parent a7122f91a1)
2 changed files with 137 additions and 6 deletions
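In short: `get_router_model_info` now only returns max tokens / cost fields for an azure deployment when `base_model` is set in `model_info`; otherwise it logs an error and the model-info lookup fails with a model-not-mapped error. A minimal usage sketch (fake key/base values, mirroring the new tests below; the exact fields returned depend on litellm's model map):

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "azure/gpt-4",
                    "api_key": "my-fake-key",
                    "api_base": "my-fake-base",
                },
                # 'base_model' maps the azure deployment to a known model,
                # so max_tokens / input_cost_per_token can be resolved
                "model_info": {"base_model": "azure/gpt-4-0125-preview", "id": "1"},
            }
        ]
    )

    deployment = router.get_deployment(model_id="1")
    model_info = router.get_router_model_info(deployment=deployment.to_json())
    print(model_info["max_tokens"], model_info["input_cost_per_token"])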
@@ -105,7 +105,9 @@ class Router:
     def __init__(
         self,
-        model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
+        model_list: Optional[
+            Union[List[DeploymentTypedDict], List[dict[str, Any]], List[Dict[str, Any]]]
+        ] = None,
         ## ASSISTANTS API ##
         assistants_config: Optional[AssistantsTypedDict] = None,
         ## CACHING ##
@@ -3970,16 +3972,36 @@ class Router:

         Augment litellm info with additional params set in `model_info`.

+        For azure models, ignore the `model:`. Only set max tokens, cost values if base_model is set.
+
         Returns
         - ModelInfo - If found -> typed dict with max tokens, input cost, etc.
+
+        Raises:
+        - ValueError -> If model is not mapped yet
         """
-        ## SET MODEL NAME
+        ## GET BASE MODEL
         base_model = deployment.get("model_info", {}).get("base_model", None)
         if base_model is None:
             base_model = deployment.get("litellm_params", {}).get("base_model", None)
-        model = base_model or deployment.get("litellm_params", {}).get("model", None)

-        ## GET LITELLM MODEL INFO
+        model = base_model
+
+        ## GET PROVIDER
+        _model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+            model=deployment.get("litellm_params", {}).get("model", ""),
+            litellm_params=LiteLLM_Params(**deployment.get("litellm_params", {})),
+        )
+
+        ## SET MODEL TO 'model=' - if base_model is None + not azure
+        if custom_llm_provider == "azure" and base_model is None:
+            verbose_router_logger.error(
+                "Could not identify azure model. Set azure 'base_model' for accurate max tokens, cost tracking, etc.- https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models"
+            )
+        else:
+            model = deployment.get("litellm_params", {}).get("model", None)
+
+        ## GET LITELLM MODEL INFO - raises exception, if model is not mapped
         model_info = litellm.get_model_info(model=model)

         ## CHECK USER SET MODEL INFO
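Note on the provider check above: `litellm.get_llm_provider` returns a 4-tuple whose second element is the provider parsed from the `provider/model` prefix, and that value is what gates the azure-only branch. A small sketch (the model string is illustrative):

    import litellm

    # "azure/gpt-4" -> provider "azure"; without base_model the router now logs an
    # error instead of guessing max tokens / cost from the deployment's model name
    _model, custom_llm_provider, _, _ = litellm.get_llm_provider(model="azure/gpt-4")
    print(custom_llm_provider)  # azure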
@@ -4365,7 +4387,7 @@ class Router:
         """
         Filter out model in model group, if:

-        - model context window < message length
+        - model context window < message length. For azure openai models, requires 'base_model' is set. - https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models
         - filter models above rpm limits
         - if region given, filter out models not in that region / unknown region
         - [TODO] function call and model doesn't support function calling
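Practical implication of the docstring change: azure deployments without `base_model` are not filtered on context window, since their limits cannot be looked up. A minimal sketch of enabling the check (assumed config, same fake credentials as the new tests below):

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "azure/gpt-4",
                    "api_key": "my-fake-key",
                    "api_base": "my-fake-base",
                },
                # no 'base_model' here -> context window unknown, deployment is kept
                "model_info": {"id": "1"},
            }
        ],
        enable_pre_call_checks=True,
        num_retries=0,
    )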
@@ -4382,6 +4404,11 @@ class Router:
         try:
             input_tokens = litellm.token_counter(messages=messages)
         except Exception as e:
+            verbose_router_logger.error(
+                "litellm.router.py::_pre_call_checks: failed to count tokens. Returning initial list of deployments. Got - {}".format(
+                    str(e)
+                )
+            )
             return _returned_deployments

         _context_window_error = False
@@ -4425,7 +4452,7 @@ class Router:
                     )
                     continue
             except Exception as e:
-                verbose_router_logger.debug("An error occurs - {}".format(str(e)))
+                verbose_router_logger.error("An error occurs - {}".format(str(e)))

             _litellm_params = deployment.get("litellm_params", {})
             model_id = deployment.get("model_info", {}).get("id", "")
@@ -16,6 +16,7 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
+from unittest.mock import AsyncMock, MagicMock, patch

 import httpx
 from dotenv import load_dotenv
@@ -1884,3 +1885,106 @@ async def test_router_model_usage(mock_response):
        else:
            print(f"allowed_fails: {allowed_fails}")
            raise e
+
+
+@pytest.mark.parametrize(
+    "model, base_model, llm_provider",
+    [
+        ("azure/gpt-4", None, "azure"),
+        ("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"),
+        ("gpt-4", None, "openai"),
+    ],
+)
+def test_router_get_model_info(model, base_model, llm_provider):
+    """
+    Test if router get model info works based on provider
+
+    For azure -> only if base model set
+    For openai -> use model=
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": model,
+                    "api_key": "my-fake-key",
+                    "api_base": "my-fake-base",
+                },
+                "model_info": {"base_model": base_model, "id": "1"},
+            }
+        ]
+    )
+
+    deployment = router.get_deployment(model_id="1")
+
+    assert deployment is not None
+
+    if llm_provider == "openai" or (base_model is not None and llm_provider == "azure"):
+        router.get_router_model_info(deployment=deployment.to_json())
+    else:
+        try:
+            router.get_router_model_info(deployment=deployment.to_json())
+            pytest.fail("Expected this to raise model not mapped error")
+        except Exception as e:
+            if "This model isn't mapped yet" in str(e):
+                pass
+
+
+@pytest.mark.parametrize(
+    "model, base_model, llm_provider",
+    [
+        ("azure/gpt-4", None, "azure"),
+        ("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"),
+        ("gpt-4", None, "openai"),
+    ],
+)
+def test_router_context_window_pre_call_check(model, base_model, llm_provider):
+    """
+    - For an azure model
+    - if no base model set
+    - don't enforce context window limits
+    """
+    try:
+        model_list = [
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": model,
+                    "api_key": "my-fake-key",
+                    "api_base": "my-fake-base",
+                },
+                "model_info": {"base_model": base_model, "id": "1"},
+            }
+        ]
+        router = Router(
+            model_list=model_list,
+            set_verbose=True,
+            enable_pre_call_checks=True,
+            num_retries=0,
+        )
+
+        litellm.token_counter = MagicMock()
+
+        def token_counter_side_effect(*args, **kwargs):
+            # Process args and kwargs if needed
+            return 1000000
+
+        litellm.token_counter.side_effect = token_counter_side_effect
+        try:
+            updated_list = router._pre_call_checks(
+                model="gpt-4",
+                healthy_deployments=model_list,
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            )
+            if llm_provider == "azure" and base_model is None:
+                assert len(updated_list) == 1
+            else:
+                pytest.fail("Expected to raise an error. Got={}".format(updated_list))
+        except Exception as e:
+            if (
+                llm_provider == "azure" and base_model is not None
+            ) or llm_provider == "openai":
+                pass
+    except Exception as e:
+        pytest.fail(f"Got unexpected exception on router! - {str(e)}")