diff --git a/litellm/utils.py b/litellm/utils.py
index d102476f31..cec43cb8ee 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2186,6 +2186,35 @@ def client(original_function):
                         )
                     else:
                         return cached_result
+
+            # CHECK MAX TOKENS
+            if (
+                kwargs.get("max_tokens", None) is not None
+                and model is not None
+                and litellm.drop_params
+                == True  # user is okay with params being modified
+                and (
+                    call_type == CallTypes.acompletion.value
+                    or call_type == CallTypes.completion.value
+                )
+            ):
+                try:
+                    max_output_tokens = get_max_tokens(model=model)
+                    user_max_tokens = kwargs.get("max_tokens")
+                    ## Scenario 1: User limit > model limit
+                    if user_max_tokens > max_output_tokens:
+                        user_max_tokens = max_output_tokens
+                    ## Scenario 2: User limit + prompt > model limit
+                    input_tokens = token_counter(
+                        model=model, messages=kwargs.get("messages")
+                    )
+                    if input_tokens > max_output_tokens:
+                        pass  # allow call to fail normally
+                    elif user_max_tokens + input_tokens > max_output_tokens:
+                        user_max_tokens = max_output_tokens - input_tokens
+                    kwargs["max_tokens"] = user_max_tokens  # apply the adjusted limit
+                except Exception as e:
+                    print_verbose(f"Error while checking max token limit: {str(e)}")
             # MODEL CALL
             result = original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
@@ -4503,7 +4531,7 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
 
 def get_max_tokens(model: str):
     """
-    Get the maximum number of tokens allowed for a given model.
+    Get the maximum number of output tokens allowed for a given model.
 
     Parameters:
     model (str): The name of the model.
@@ -4543,7 +4571,10 @@
 
     try:
         if model in litellm.model_cost:
-            return litellm.model_cost[model]["max_tokens"]
+            if "max_output_tokens" in litellm.model_cost[model]:
+                return litellm.model_cost[model]["max_output_tokens"]
+            elif "max_tokens" in litellm.model_cost[model]:
+                return litellm.model_cost[model]["max_tokens"]
         model, custom_llm_provider, _, _ = get_llm_provider(model=model)
        if custom_llm_provider == "huggingface":
             max_tokens = _get_max_position_embeddings(model_name=model)
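For readers skimming the first hunk, here is the clamping rule pulled out as a minimal standalone sketch; clamp_user_max_tokens and the literal token counts are illustrative only and are not part of the patch.

# Hypothetical helper mirroring the "CHECK MAX TOKENS" hunk above; the name
# and the numbers are for illustration only and do not exist in litellm/utils.py.
def clamp_user_max_tokens(user_max_tokens: int, input_tokens: int, max_output_tokens: int) -> int:
    ## Scenario 1: User limit > model limit
    if user_max_tokens > max_output_tokens:
        user_max_tokens = max_output_tokens
    ## Scenario 2: User limit + prompt > model limit
    if input_tokens > max_output_tokens:
        pass  # allow call to fail normally
    elif user_max_tokens + input_tokens > max_output_tokens:
        user_max_tokens = max_output_tokens - input_tokens
    return user_max_tokens

# e.g. a 4096-token model limit, a 4000-token prompt, and max_tokens=500
# yields 96, keeping prompt + completion within the model's window
assert clamp_user_max_tokens(500, 4000, 4096) == 96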
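The get_max_tokens hunks change the cost-map lookup to prefer an explicit output limit. The fragment below sketches that order with a made-up cost-map entry; real entries live in litellm.model_cost.

# Made-up cost-map entry; real data comes from litellm.model_cost.
entry = {"max_tokens": 8192, "max_output_tokens": 4096}

# Same lookup order as the patched get_max_tokens: prefer the explicit
# output limit, fall back to the older max_tokens field.
limit = entry.get("max_output_tokens", entry.get("max_tokens"))
print(limit)  # 4096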