From 6e53de5462808708b5ae6c5f40934f4cf67e4200 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 26 Jun 2024 16:02:23 -0700
Subject: [PATCH] fix(router.py): only return 'max_tokens', 'input_cost_per_token', etc. in 'get_router_model_info' if base_model is set

---
 litellm/router.py            |  39 +++++++++++--
 litellm/tests/test_router.py | 104 +++++++++++++++++++++++++++++++++++
 2 files changed, 137 insertions(+), 6 deletions(-)

diff --git a/litellm/router.py b/litellm/router.py
index e2f7ce8b21..d069fa9d3c 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -105,7 +105,9 @@ class Router:
 
     def __init__(
         self,
-        model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
+        model_list: Optional[
+            Union[List[DeploymentTypedDict], List[dict[str, Any]], List[Dict[str, Any]]]
+        ] = None,
         ## ASSISTANTS API ##
         assistants_config: Optional[AssistantsTypedDict] = None,
         ## CACHING ##
@@ -3970,16 +3972,36 @@ class Router:
 
         Augment litellm info with additional params set in `model_info`.
 
+        For azure models, ignore the `model:` value. Only set max tokens / cost values if `base_model` is set.
+
         Returns
        - ModelInfo - If found -> typed dict with max tokens, input cost, etc.
+
+        Raises:
+        - ValueError -> If model is not mapped yet
         """
-        ## SET MODEL NAME
+        ## GET BASE MODEL
         base_model = deployment.get("model_info", {}).get("base_model", None)
         if base_model is None:
             base_model = deployment.get("litellm_params", {}).get("base_model", None)
-        model = base_model or deployment.get("litellm_params", {}).get("model", None)
 
-        ## GET LITELLM MODEL INFO
+        model = base_model
+
+        ## GET PROVIDER
+        _model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+            model=deployment.get("litellm_params", {}).get("model", ""),
+            litellm_params=LiteLLM_Params(**deployment.get("litellm_params", {})),
+        )
+
+        ## SET MODEL TO 'model=' - if base_model is None + not azure
+        if custom_llm_provider == "azure" and base_model is None:
+            verbose_router_logger.error(
+                "Could not identify azure model. Set azure 'base_model' for accurate max tokens, cost tracking, etc. - https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models"
+            )
+        else:
+            model = deployment.get("litellm_params", {}).get("model", None)
+
+        ## GET LITELLM MODEL INFO - raises exception, if model is not mapped
         model_info = litellm.get_model_info(model=model)
 
         ## CHECK USER SET MODEL INFO
@@ -4365,7 +4387,7 @@ class Router:
         """
         Filter out model in model group, if:
 
-        - model context window < message length
+        - model context window < message length. For azure openai models, requires 'base_model' to be set - https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models
         - filter models above rpm limits
         - if region given, filter out models not in that region / unknown region
         - [TODO] function call and model doesn't support function calling
@@ -4382,6 +4404,11 @@ class Router:
         try:
             input_tokens = litellm.token_counter(messages=messages)
         except Exception as e:
+            verbose_router_logger.error(
+                "litellm.router.py::_pre_call_checks: failed to count tokens. Returning initial list of deployments. Got - {}".format(
+                    str(e)
+                )
+            )
             return _returned_deployments
 
         _context_window_error = False
@@ -4425,7 +4452,7 @@ class Router:
                     )
                     continue
             except Exception as e:
-                verbose_router_logger.debug("An error occurs - {}".format(str(e)))
+                verbose_router_logger.error("An error occurred - {}".format(str(e)))
 
             _litellm_params = deployment.get("litellm_params", {})
             model_id = deployment.get("model_info", {}).get("id", "")
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index 3237c8084a..db240e3586 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -16,6 +16,7 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import httpx
 from dotenv import load_dotenv
@@ -1884,3 +1885,106 @@ async def test_router_model_usage(mock_response):
     else:
         print(f"allowed_fails: {allowed_fails}")
         raise e
+
+
+@pytest.mark.parametrize(
+    "model, base_model, llm_provider",
+    [
+        ("azure/gpt-4", None, "azure"),
+        ("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"),
+        ("gpt-4", None, "openai"),
+    ],
+)
+def test_router_get_model_info(model, base_model, llm_provider):
+    """
+    Test if router get model info works based on provider
+
+    For azure -> only if base model set
+    For openai -> use model=
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": model,
+                    "api_key": "my-fake-key",
+                    "api_base": "my-fake-base",
+                },
+                "model_info": {"base_model": base_model, "id": "1"},
+            }
+        ]
+    )
+
+    deployment = router.get_deployment(model_id="1")
+
+    assert deployment is not None
+
+    if llm_provider == "openai" or (base_model is not None and llm_provider == "azure"):
+        router.get_router_model_info(deployment=deployment.to_json())
+    else:
+        try:
+            router.get_router_model_info(deployment=deployment.to_json())
+            pytest.fail("Expected this to raise model not mapped error")
+        except Exception as e:
+            if "This model isn't mapped yet" not in str(e):
+                raise e
+
+
+@pytest.mark.parametrize(
+    "model, base_model, llm_provider",
+    [
+        ("azure/gpt-4", None, "azure"),
+        ("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"),
+        ("gpt-4", None, "openai"),
+    ],
+)
+def test_router_context_window_pre_call_check(model, base_model, llm_provider):
+    """
+    - For an azure model
+    - if no base model set
+    - don't enforce context window limits
+    """
+    try:
+        model_list = [
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": model,
+                    "api_key": "my-fake-key",
+                    "api_base": "my-fake-base",
+                },
+                "model_info": {"base_model": base_model, "id": "1"},
+            }
+        ]
+        router = Router(
+            model_list=model_list,
+            set_verbose=True,
+            enable_pre_call_checks=True,
+            num_retries=0,
+        )
+
+        litellm.token_counter = MagicMock()
+
+        def token_counter_side_effect(*args, **kwargs):
+            # force a very large token count so the context window check trips
+            return 1000000
+
+        litellm.token_counter.side_effect = token_counter_side_effect
+        try:
+            updated_list = router._pre_call_checks(
+                model="gpt-4",
+                healthy_deployments=model_list,
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            )
+            if llm_provider == "azure" and base_model is None:
+                assert len(updated_list) == 1
+            else:
+                pytest.fail("Expected to raise an error. Got={}".format(updated_list))
+        except Exception as e:
+            # only expected when the context window check could actually run
+            assert (
+                llm_provider == "azure" and base_model is not None
+            ) or llm_provider == "openai", str(e)
+    except Exception as e:
+        pytest.fail(f"Got unexpected exception on router! - {str(e)}")
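
Usage sketch (illustrative, not taken from the patch itself): assuming litellm with this change installed, an azure deployment with `base_model` set in `model_info` lets `get_router_model_info` resolve max tokens and per-token costs from litellm's model map; without `base_model`, the same lookup raises litellm's "model isn't mapped yet" error. The deployment name, api key, and api base below are placeholders, mirroring the test fixtures above.

    import litellm
    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "azure/my-gpt-4-deployment",  # placeholder azure deployment name
                    "api_key": "my-fake-key",
                    "api_base": "my-fake-base",
                },
                # base_model tells the router which mapped model to use for
                # max tokens / cost lookups on azure deployments
                "model_info": {"base_model": "azure/gpt-4-0125-preview", "id": "1"},
            }
        ]
    )

    deployment = router.get_deployment(model_id="1")
    assert deployment is not None

    # with base_model set, this returns max_tokens, input_cost_per_token, etc.
    info = router.get_router_model_info(deployment=deployment.to_json())
    print(info.get("max_tokens"), info.get("input_cost_per_token"))

    # if "base_model" were omitted for this azure deployment, the same call
    # would raise litellm's "This model isn't mapped yet" error instead.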