diff --git a/litellm/__init__.py b/litellm/__init__.py
index f1cc32cd1..a8d9a80a2 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -738,6 +738,7 @@ openai_image_generation_models = ["dall-e-2", "dall-e-3"]
 from .timeout import timeout
 from .cost_calculator import completion_cost
 from litellm.litellm_core_utils.litellm_logging import Logging
+from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
 from .utils import (
     client,
     exception_type,
diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py
index 7b911895d..a325a6885 100644
--- a/litellm/litellm_core_utils/core_helpers.py
+++ b/litellm/litellm_core_utils/core_helpers.py
@@ -1,5 +1,5 @@
 # What is this?
-## Helper utilities for the model response objects
+## Helper utilities
 
 
 def map_finish_reason(
diff --git a/litellm/litellm_core_utils/token_counter.py b/litellm/litellm_core_utils/token_counter.py
new file mode 100644
index 000000000..ebc0765c0
--- /dev/null
+++ b/litellm/litellm_core_utils/token_counter.py
@@ -0,0 +1,89 @@
+# What is this?
+## Helper utilities for token counting
+from typing import Optional
+
+import litellm
+from litellm import verbose_logger
+
+
+def get_modified_max_tokens(
+    model: str,
+    base_model: str,
+    messages: Optional[list],
+    user_max_tokens: Optional[int],
+    buffer_perc: Optional[float] = None,
+    buffer_num: Optional[float] = None,
+) -> Optional[int]:
+    """
+    Params:
+    - model: the model the user passed in, used to look up model info
+    - base_model: the underlying model, used for token counting and the max output token lookup
+    - messages: the request messages, used to count input tokens
+    - user_max_tokens: the max_tokens value requested by the user
+    - buffer_perc: fraction of input tokens added as a safety buffer (defaults to 0.1)
+    - buffer_num: minimum token buffer (defaults to 10)
+
+    Returns the user's max output tokens, adjusted for:
+    - the size of the input - for models where input + output can't exceed X
+    - model max output tokens - for models where there is a separate output token limit
+    """
+    try:
+        if user_max_tokens is None:
+            return None
+
+        ## MODEL INFO
+        _model_info = litellm.get_model_info(model=model)
+
+        max_output_tokens = litellm.get_max_tokens(
+            model=base_model
+        )  # may be None if the model's output limit isn't known
+
+        ## UNKNOWN MAX OUTPUT TOKENS - return user defined amount
+        if max_output_tokens is None:
+            return user_max_tokens
+
+        input_tokens = litellm.token_counter(model=base_model, messages=messages)
+
+        # token buffer
+        if buffer_perc is None:
+            buffer_perc = 0.1
+        if buffer_num is None:
+            buffer_num = 10
+        token_buffer = max(
+            buffer_perc * input_tokens, buffer_num
+        )  # give at least a 10 token buffer. token counting can be imprecise.
+
+        input_tokens += int(token_buffer)
+        verbose_logger.debug(
+            f"max_output_tokens: {max_output_tokens}, user_max_tokens: {user_max_tokens}"
+        )
+        ## CASE 1: model input + output can't exceed X - happens when max input = max output, e.g. gpt-3.5-turbo
+        if _model_info["max_input_tokens"] == max_output_tokens:
+            verbose_logger.debug(
+                f"input_tokens: {input_tokens}, max_output_tokens: {max_output_tokens}"
+            )
+            if input_tokens > max_output_tokens:
+                pass  # allow call to fail normally - don't set max_tokens to a negative value
+            elif (
+                user_max_tokens + input_tokens > max_output_tokens
+            ):  # we can still modify to keep it positive but below the limit
+                verbose_logger.debug(
+                    f"MODIFYING MAX TOKENS - user_max_tokens={user_max_tokens}, input_tokens={input_tokens}, max_output_tokens={max_output_tokens}"
+                )
+                user_max_tokens = int(max_output_tokens - input_tokens)
+        ## CASE 2: user_max_tokens > model max output tokens
+        elif user_max_tokens > max_output_tokens:
+            user_max_tokens = max_output_tokens
+
+        verbose_logger.debug(
+            f"litellm.litellm_core_utils.token_counter.py::get_modified_max_tokens() - user_max_tokens: {user_max_tokens}"
+        )
+
+        return user_max_tokens
+    except Exception as e:
+        verbose_logger.error(
+            "litellm.litellm_core_utils.token_counter.py::get_modified_max_tokens() - Error while checking max token limit: {}\nmodel={}, base_model={}".format(
+                str(e), model, base_model
+            )
+        )
+        return user_max_tokens
diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index c570c08cf..b8c26fd2a 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -1,4 +1,7 @@
 model_list:
+- model_name: gemini-1.5-flash-gemini
+  litellm_params:
+    model: gemini/gemini-1.5-flash
 - model_name: gemini-1.5-flash-gemini
   litellm_params:
     model: gemini/gemini-1.5-flash
diff --git a/litellm/tests/test_token_counter.py b/litellm/tests/test_token_counter.py
index 2c3eb89fd..e61762131 100644
--- a/litellm/tests/test_token_counter.py
+++ b/litellm/tests/test_token_counter.py
@@ -1,15 +1,25 @@
 #### What this tests ####
 #    This tests litellm.token_counter() function
 
-import sys, os
+import os
+import sys
 import traceback
+
 import pytest
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import time
-from litellm import token_counter, create_pretrained_tokenizer, encode, decode
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from litellm import (
+    create_pretrained_tokenizer,
+    decode,
+    encode,
+    get_modified_max_tokens,
+    token_counter,
+)
 
 from litellm.tests.large_text import text
 
@@ -227,3 +237,55 @@ def test_openai_token_with_image_and_text():
 
     token_count = token_counter(model=model, messages=messages)
     print(token_count)
+
+
+@pytest.mark.parametrize(
+    "model, base_model, input_tokens, user_max_tokens, expected_value",
+    [
+        ("random-model", "random-model", 1024, 1024, 1024),
+        ("command", "command", 1000000, None, None),  # model max = 4096
+        ("command", "command", 4000, 256, 96),  # model max = 4096
+        ("command", "command", 4000, 10, 10),  # model max = 4096
+        ("gpt-3.5-turbo", "gpt-3.5-turbo", 4000, 5000, 4096),  # model max output = 4096
+    ],
+)
+def test_get_modified_max_tokens(
+    model, base_model, input_tokens, user_max_tokens, expected_value
+):
+    """
+    - Test when max_output is not known => expect user_max_tokens
+    - Test when max_output == max_input,
+        - input > max_output, no max_tokens => expect None
+        - input + max_tokens > max_output => expect remainder
+        - input + max_tokens < max_output => expect max_tokens
+    - Test when max_tokens > max_output => expect max_output
+    """
+    args = locals()
+    import litellm
+
+    litellm.token_counter = MagicMock()
+
+    def _mock_token_counter(*args, **kwargs):
+        return input_tokens
+
+    litellm.token_counter.side_effect = _mock_token_counter
+    print(f"_mock_token_counter: {_mock_token_counter()}")
+    messages = [{"role": "user", "content": "Hello world!"}]
+
+    calculated_value = get_modified_max_tokens(
+        model=model,
+        base_model=base_model,
+        messages=messages,
+        user_max_tokens=user_max_tokens,
+        buffer_perc=0,
+        buffer_num=0,
+    )
+
+    if expected_value is None:
+        assert calculated_value is None
+    else:
+        assert (
+            calculated_value == expected_value
+        ), "Got={}, Expected={}, Params={}".format(
+            calculated_value, expected_value, args
+        )
diff --git a/litellm/utils.py b/litellm/utils.py
index 53f5f9848..c53e8f338 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -54,6 +54,7 @@ from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_s
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_logging,
 )
+from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.types.utils import (
     CallTypes,
@@ -813,7 +814,7 @@ def client(original_function):
                 kwargs.get("max_tokens", None) is not None
                 and model is not None
                 and litellm.modify_params
-                == True  # user is okay with params being modified
+                is True  # user is okay with params being modified
                 and (
                     call_type == CallTypes.acompletion.value
                     or call_type == CallTypes.completion.value
@@ -823,28 +824,19 @@ def client(original_function):
                     base_model = model
                     if kwargs.get("hf_model_name", None) is not None:
                         base_model = f"huggingface/{kwargs.get('hf_model_name')}"
-                    max_output_tokens = (
-                        get_max_tokens(model=base_model) or 4096
-                    )  # assume min context window is 4k tokens
-                    user_max_tokens = kwargs.get("max_tokens")
-                    ## Scenario 1: User limit + prompt > model limit
                     messages = None
                     if len(args) > 1:
                         messages = args[1]
                     elif kwargs.get("messages", None):
                         messages = kwargs["messages"]
-                    input_tokens = token_counter(model=base_model, messages=messages)
-                    input_tokens += max(
-                        0.1 * input_tokens, 10
-                    )  # give at least a 10 token buffer. token counting can be imprecise.
-                    if input_tokens > max_output_tokens:
-                        pass  # allow call to fail normally
-                    elif user_max_tokens + input_tokens > max_output_tokens:
-                        user_max_tokens = max_output_tokens - input_tokens
-                    print_verbose(f"user_max_tokens: {user_max_tokens}")
-                    kwargs["max_tokens"] = int(
-                        round(user_max_tokens)
-                    )  # make sure max tokens is always an int
+                    user_max_tokens = kwargs.get("max_tokens")
+                    modified_max_tokens = get_modified_max_tokens(
+                        model=model,
+                        base_model=base_model,
+                        messages=messages,
+                        user_max_tokens=user_max_tokens,
+                    )
+                    kwargs["max_tokens"] = modified_max_tokens
                 except Exception as e:
                     print_verbose(f"Error while checking max token limit: {str(e)}")
             # MODEL CALL
@@ -4374,7 +4366,7 @@ def get_utc_datetime():
     return datetime.utcnow()  # type: ignore
 
 
-def get_max_tokens(model: str):
+def get_max_tokens(model: str) -> Optional[int]:
     """
     Get the maximum number of output tokens allowed for a given model.
 
@@ -4428,7 +4420,8 @@
             return litellm.model_cost[model]["max_tokens"]
         else:
             raise Exception()
-    except:
+        return None
+    except Exception:
         raise Exception(
             f"Model {model} isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
         )
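
Reviewer note (not part of the patch): a minimal usage sketch of the new helper, assuming litellm with this change installed and "gpt-3.5-turbo" present in the model cost map. It mirrors the last parametrized test case above: when the model has a separate output limit, the user's max_tokens is clamped down to it (CASE 2 in token_counter.py).

    # illustrative sketch only - not part of the diff
    from litellm import get_modified_max_tokens

    messages = [{"role": "user", "content": "Hello world!"}]

    # gpt-3.5-turbo's max output tokens is 4096 (per the test above),
    # so a requested max_tokens of 5000 is reduced to 4096.
    adjusted = get_modified_max_tokens(
        model="gpt-3.5-turbo",
        base_model="gpt-3.5-turbo",
        messages=messages,
        user_max_tokens=5000,
        buffer_perc=0,
        buffer_num=0,
    )
    print(adjusted)  # expected: 4096

In the proxy/client path (utils.py), the same call runs automatically when litellm.modify_params is True and the user supplied max_tokens; buffer_perc and buffer_num fall back to 0.1 and 10 when not passed.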