diff --git a/litellm/tests/test_utils.py b/litellm/tests/test_utils.py
index 57b93df9c..8f99bd665 100644
--- a/litellm/tests/test_utils.py
+++ b/litellm/tests/test_utils.py
@@ -22,6 +22,7 @@ from litellm.utils import (
     token_counter,
     create_pretrained_tokenizer,
     create_tokenizer,
+    get_max_tokens,
 )

 # Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils'
@@ -372,3 +373,16 @@ def test_supports_function_calling():
         assert litellm.supports_function_calling(model="claude-2") == False
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+def test_get_max_token_unit_test():
+    """
+    More complete testing in `test_completion_cost.py`
+    """
+    model = "bedrock/anthropic.claude-3-haiku-20240307-v1:0"
+
+    max_tokens = get_max_tokens(
+        model
+    )  # Returns a number instead of throwing an Exception
+
+    assert isinstance(max_tokens, int)
diff --git a/litellm/utils.py b/litellm/utils.py
index 3e9fdccd9..e9af85601 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -7065,6 +7065,11 @@ def get_max_tokens(model: str):
         if custom_llm_provider == "huggingface":
             max_tokens = _get_max_position_embeddings(model_name=model)
             return max_tokens
+        if model in litellm.model_cost:  # check if extracted model is in model_list
+            if "max_output_tokens" in litellm.model_cost[model]:
+                return litellm.model_cost[model]["max_output_tokens"]
+            elif "max_tokens" in litellm.model_cost[model]:
+                return litellm.model_cost[model]["max_tokens"]
         else:
             raise Exception()
     except:
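
For reference, a minimal sketch of how the patched lookup is expected to behave (assumes `litellm.model_cost` ships an entry for this Bedrock model id; the exact limit returned depends on the bundled cost map):

    import litellm
    from litellm.utils import get_max_tokens

    # Before this patch, a "provider/model" id whose extracted model name was
    # not handled by a provider-specific branch raised an Exception; with the
    # patch it falls back to the litellm.model_cost map, preferring
    # "max_output_tokens" over "max_tokens".
    limit = get_max_tokens("bedrock/anthropic.claude-3-haiku-20240307-v1:0")
    assert isinstance(limit, int)  # e.g. the model's max_output_tokens value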