diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 640a3b2cf..78d7dc70c 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -4,7 +4,17 @@ model_list:
       model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
       api_key: my-fake-key
       aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
+      mock_response: "Hello world 1"
+    model_info:
+      max_input_tokens: 0 # trigger context window fallback
+  - model_name: my-fake-model
+    litellm_params:
+      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
+      api_key: my-fake-key
+      aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
+      mock_response: "Hello world 2"
+    model_info:
+      max_input_tokens: 0
 
-litellm_settings:
-  success_callback: ["langfuse"]
-  failure_callback: ["langfuse"]
+router_settings:
+  enable_pre_call_checks: True
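For illustration (not part of the patch): with `enable_pre_call_checks: True` and every deployment capped at `max_input_tokens: 0`, any non-empty prompt fails the context-window pre-call check, so the router raises the more descriptive `ContextWindowExceededError` added below. A minimal sketch, assuming `Router` and `mock_response` behave as in the tests at the end of this diff:

```python
# Hypothetical repro of the config above; values mirror the YAML, not this diff.
import litellm
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "my-fake-model",
            "litellm_params": {
                "model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
                "api_key": "my-fake-key",
                "mock_response": "Hello world 1",
            },
            "model_info": {"max_input_tokens": 0},  # trips the pre-call check
        }
    ],
    enable_pre_call_checks=True,
)

try:
    router.completion(
        model="my-fake-model",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
except litellm.ContextWindowExceededError as e:
    # All deployments are filtered out before the call, so the router raises
    # the descriptive error built from _potential_error_str in this patch.
    print(e)
```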
diff --git a/litellm/router.py b/litellm/router.py
index e9b0cc00a..6163da487 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -404,6 +404,7 @@ class Router:
             litellm.failure_callback = [self.deployment_callback_on_failure]
         print(  # noqa
             f"Intialized router with Routing strategy: {self.routing_strategy}\n\n"
+            f"Routing enable_pre_call_checks: {self.enable_pre_call_checks}\n\n"
             f"Routing fallbacks: {self.fallbacks}\n\n"
             f"Routing content fallbacks: {self.content_policy_fallbacks}\n\n"
             f"Routing context window fallbacks: {self.context_window_fallbacks}\n\n"
@@ -3915,9 +3916,38 @@ class Router:
             raise Exception("Model invalid format - {}".format(type(model)))
         return None
 
+    def get_router_model_info(self, deployment: dict) -> ModelMapInfo:
+        """
+        For a given deployment, return its model info (max tokens, input cost, output cost, etc.).
+
+        Augments litellm's model-map info with any params set in the deployment's `model_info` (user-set values take precedence).
+
+        Returns
+        - ModelInfo - If found -> typed dict with max tokens, input cost, etc.
+        """
+        ## SET MODEL NAME
+        base_model = deployment.get("model_info", {}).get("base_model", None)
+        if base_model is None:
+            base_model = deployment.get("litellm_params", {}).get("base_model", None)
+        model = base_model or deployment.get("litellm_params", {}).get("model", None)
+
+        ## GET LITELLM MODEL INFO
+        model_info = litellm.get_model_info(model=model)
+
+        ## CHECK USER SET MODEL INFO
+        user_model_info = deployment.get("model_info", {})
+
+        model_info.update(user_model_info)
+
+        return model_info
+
     def get_model_info(self, id: str) -> Optional[dict]:
         """
         For a given model id, return the model info
+
+        Returns
+        - dict: the model in list with 'model_name', 'litellm_params', Optional['model_info']
+        - None: could not find deployment in list
         """
         for model in self.model_list:
             if "model_info" in model and "id" in model["model_info"]:
@@ -4307,6 +4337,7 @@ class Router:
             return _returned_deployments
 
         _context_window_error = False
+        _potential_error_str = ""
         _rate_limit_error = False
 
         ## get model group RPM ##
@@ -4327,7 +4358,7 @@ class Router:
                 model = base_model or deployment.get("litellm_params", {}).get(
                     "model", None
                 )
-                model_info = litellm.get_model_info(model=model)
+                model_info = self.get_router_model_info(deployment=deployment)
 
                 if (
                     isinstance(model_info, dict)
@@ -4339,6 +4370,11 @@ class Router:
                     ):
                         invalid_model_indices.append(idx)
                         _context_window_error = True
+                        _potential_error_str += (
+                            "Model={}, Max Input Tokens={}, Got={}".format(
+                                model, model_info["max_input_tokens"], input_tokens
+                            )
+                        )
                         continue
             except Exception as e:
                 verbose_router_logger.debug("An error occurs - {}".format(str(e)))
@@ -4440,7 +4476,9 @@ class Router:
             )
         elif _context_window_error == True:
             raise litellm.ContextWindowExceededError(
-                message="Context Window exceeded for given call",
+                message="litellm._pre_call_checks: Context Window exceeded for given call. No models have context window large enough for this call.\n{}".format(
+                    _potential_error_str
+                ),
                 model=model,
                 llm_provider="",
                 response=httpx.Response(
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index 2e8814327..84ea9e1c9 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -755,6 +755,7 @@ def test_router_context_window_check_pre_call_check_in_group():
                 "api_version": os.getenv("AZURE_API_VERSION"),
                 "api_base": os.getenv("AZURE_API_BASE"),
                 "base_model": "azure/gpt-35-turbo",
+                "mock_response": "Hello world 1!",
             },
         },
         {
@@ -762,6 +763,7 @@ def test_router_context_window_check_pre_call_check_in_group():
             "litellm_params": {  # params for litellm completion/embedding call
                 "model": "gpt-3.5-turbo-1106",
                 "api_key": os.getenv("OPENAI_API_KEY"),
+                "mock_response": "Hello world 2!",
             },
         },
     ]
@@ -777,6 +779,8 @@ def test_router_context_window_check_pre_call_check_in_group():
         )
 
         print(f"response: {response}")
+
+        assert response.choices[0].message.content == "Hello world 2!"
     except Exception as e:
         pytest.fail(f"Got unexpected exception on router! - {str(e)}")
 
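To make the merge semantics of the new `get_router_model_info` concrete, here is a standalone sketch (illustrative only, not part of the patch): litellm's model-map entry is fetched first, then `dict.update` lets user-set `model_info` keys override it. That is how `max_input_tokens: 0` in the config wins over Claude 3 Sonnet's real limits:

```python
# Hypothetical walkthrough of the merge done by Router.get_router_model_info.
import litellm

deployment = {
    "model_name": "my-fake-model",
    "litellm_params": {"model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0"},
    "model_info": {"max_input_tokens": 0},  # user override
}

info = deployment.get("model_info", {})
params = deployment.get("litellm_params", {})

# base_model (if set) takes precedence when looking up litellm's model map
model = info.get("base_model") or params.get("base_model") or params.get("model")

model_info = litellm.get_model_info(model=model)  # litellm model-map defaults
model_info.update(info)  # user-set keys win

# Any non-empty prompt now exceeds the limit, so the pre-call check
# filters this deployment out.
print(model_info["max_input_tokens"])  # -> 0
```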