fix(token_counter.py): New `get_modified_max_tokens` helper func

Fixes https://github.com/BerriAI/litellm/issues/4439
Krrish Dholakia 2024-06-27 15:38:09 -07:00
parent 0c5014c323
commit d421486a45
6 changed files with 165 additions and 23 deletions
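
At a glance, the new helper caps a caller's requested max_tokens against the model's limits before the request is sent. A minimal sketch of how it is meant to be called, using the signature added below; the concrete numbers are illustrative (they assume litellm's bundled model map) and are not part of the commit:

from litellm.litellm_core_utils.token_counter import get_modified_max_tokens

# Request 5000 output tokens from gpt-3.5-turbo, whose output limit is 4096.
# Passing None for the buffer arguments falls back to the helper's defaults
# (a 10% / 10-token buffer on the counted input).
adjusted = get_modified_max_tokens(
    model="gpt-3.5-turbo",
    base_model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello world!"}],
    user_max_tokens=5000,
    buffer_perc=None,
    buffer_num=None,
)
print(adjusted)  # 4096, matching the expectation in the parametrized test below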

View file

@@ -738,6 +738,7 @@ openai_image_generation_models = ["dall-e-2", "dall-e-3"]
from .timeout import timeout
from .cost_calculator import completion_cost
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
from .utils import (
client,
exception_type,

View file

@@ -1,5 +1,5 @@
# What is this?
## Helper utilities for the model response objects
## Helper utilities
def map_finish_reason(

View file

@@ -0,0 +1,83 @@
# What is this?
## Helper utilities for token counting
from typing import Optional
import litellm
from litellm import verbose_logger
def get_modified_max_tokens(
    model: str,
    base_model: str,
    messages: Optional[list],
    user_max_tokens: Optional[int],
    buffer_perc: Optional[float],
    buffer_num: Optional[float],
) -> Optional[int]:
    """
    Params:
        model: the model the request is sent to (used to look up model info, e.g. max input tokens).
        base_model: the underlying model, used for token counting and the max output token lookup.
        messages: the request messages, used to count input tokens.
        user_max_tokens: the max_tokens value the caller requested.
        buffer_perc: fraction of the input tokens added as a safety buffer (defaults to 0.1).
        buffer_num: minimum absolute token buffer (defaults to 10).

    Returns the user's max output tokens, adjusted for:
    - the size of the input - for models where input + output can't exceed X
    - the model's max output tokens - for models where there is a separate output token limit
    """
    try:
        if user_max_tokens is None:
            return None

        ## MODEL INFO
        _model_info = litellm.get_model_info(model=model)

        max_output_tokens = litellm.get_max_tokens(
            model=base_model
        )  # the model's max output tokens, if known

        ## UNKNOWN MAX OUTPUT TOKENS - return user-defined amount
        if max_output_tokens is None:
            return user_max_tokens

        input_tokens = litellm.token_counter(model=base_model, messages=messages)

        # token buffer
        if buffer_perc is None:
            buffer_perc = 0.1
        if buffer_num is None:
            buffer_num = 10
        token_buffer = max(
            buffer_perc * input_tokens, buffer_num
        )  # give at least a 10 token buffer. token counting can be imprecise.
        input_tokens += int(token_buffer)

        verbose_logger.debug(
            f"max_output_tokens: {max_output_tokens}, user_max_tokens: {user_max_tokens}"
        )

        ## CASE 1: model input + output can't exceed X - happens when max input = max output, e.g. gpt-3.5-turbo
        if _model_info["max_input_tokens"] == max_output_tokens:
            verbose_logger.debug(
                f"input_tokens: {input_tokens}, max_output_tokens: {max_output_tokens}"
            )
            if input_tokens > max_output_tokens:
                pass  # allow call to fail normally - don't set max_tokens to negative.
            elif (
                user_max_tokens + input_tokens > max_output_tokens
            ):  # we can still modify to keep it positive but below the limit
                verbose_logger.debug(
                    f"MODIFYING MAX TOKENS - user_max_tokens={user_max_tokens}, input_tokens={input_tokens}, max_output_tokens={max_output_tokens}"
                )
                user_max_tokens = int(max_output_tokens - input_tokens)
        ## CASE 2: user_max_tokens > model max output tokens
        elif user_max_tokens > max_output_tokens:
            user_max_tokens = max_output_tokens

        verbose_logger.debug(
            f"litellm.litellm_core_utils.token_counter.py::get_modified_max_tokens() - user_max_tokens: {user_max_tokens}"
        )
        return user_max_tokens
    except Exception as e:
        verbose_logger.error(
            "litellm.litellm_core_utils.token_counter.py::get_modified_max_tokens() - Error while checking max token limit: {}\nmodel={}, base_model={}".format(
                str(e), model, base_model
            )
        )
        return user_max_tokens
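
To make the CASE 1 arithmetic concrete, here is one row of the parametrized test added further down in this commit: "command" is treated there as a model whose max input equals its max output of 4096, and both buffers are set to 0 so the numbers stay exact.

max_output_tokens = 4096  # "command": max input and max output share one 4096-token window
input_tokens = 4000       # the test mocks litellm.token_counter to return this
user_max_tokens = 256     # the requested completion size

# 4000 + 256 > 4096, so the helper trims the request to the room that is left.
assert max_output_tokens - input_tokens == 96  # the test's expected value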

View file

@@ -1,4 +1,7 @@
model_list:
  - model_name: gemini-1.5-flash-gemini
    litellm_params:
      model: gemini/gemini-1.5-flash
  - model_name: gemini-1.5-flash-gemini
    litellm_params:
      model: gemini/gemini-1.5-flash

View file

@@ -1,15 +1,25 @@
#### What this tests ####
# This tests the litellm.token_counter() function
import sys, os
import os
import sys
import traceback
import pytest
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import time
from litellm import token_counter, create_pretrained_tokenizer, encode, decode
from unittest.mock import AsyncMock, MagicMock, patch
from litellm import (
    create_pretrained_tokenizer,
    decode,
    encode,
    get_modified_max_tokens,
    token_counter,
)
from litellm.tests.large_text import text
@@ -227,3 +237,55 @@ def test_openai_token_with_image_and_text():
token_count = token_counter(model=model, messages=messages)
print(token_count)
@pytest.mark.parametrize(
    "model, base_model, input_tokens, user_max_tokens, expected_value",
    [
        ("random-model", "random-model", 1024, 1024, 1024),
        ("command", "command", 1000000, None, None),  # model max = 4096
        ("command", "command", 4000, 256, 96),  # model max = 4096
        ("command", "command", 4000, 10, 10),  # model max = 4096
        ("gpt-3.5-turbo", "gpt-3.5-turbo", 4000, 5000, 4096),  # model max output = 4096
    ],
)
def test_get_modified_max_tokens(
    model, base_model, input_tokens, user_max_tokens, expected_value
):
    """
    - Test when max_output is not known => expect user_max_tokens
    - Test when max_output == max_input,
        - input > max_output, no max_tokens => expect None
        - input + max_tokens > max_output => expect remainder
        - input + max_tokens < max_output => expect max_tokens
    - Test when max_tokens > max_output => expect max_output
    """
    args = locals()
    import litellm

    litellm.token_counter = MagicMock()

    def _mock_token_counter(*args, **kwargs):
        return input_tokens

    litellm.token_counter.side_effect = _mock_token_counter
    print(f"_mock_token_counter: {_mock_token_counter()}")
    messages = [{"role": "user", "content": "Hello world!"}]

    calculated_value = get_modified_max_tokens(
        model=model,
        base_model=base_model,
        messages=messages,
        user_max_tokens=user_max_tokens,
        buffer_perc=0,
        buffer_num=0,
    )

    if expected_value is None:
        assert calculated_value is None
    else:
        assert (
            calculated_value == expected_value
        ), "Got={}, Expected={}, Params={}".format(
            calculated_value, expected_value, args
        )
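
To run just the new parametrized cases locally, something like the following should work; the test file name is not shown in this view, so litellm/tests/test_token_counter.py is an assumption based on the module's imports.

import pytest

# Hypothetical path - the diff view omits file names.
pytest.main(["-s", "-k", "test_get_modified_max_tokens", "litellm/tests/test_token_counter.py"])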

View file

@@ -54,6 +54,7 @@ from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_s
from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging,
)
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import (
CallTypes,
@@ -813,7 +814,7 @@ def client(original_function):
kwargs.get("max_tokens", None) is not None
and model is not None
and litellm.modify_params
== True # user is okay with params being modified
is True # user is okay with params being modified
and (
call_type == CallTypes.acompletion.value
or call_type == CallTypes.completion.value
@@ -823,28 +824,19 @@ def client(original_function):
base_model = model
if kwargs.get("hf_model_name", None) is not None:
base_model = f"huggingface/{kwargs.get('hf_model_name')}"
max_output_tokens = (
get_max_tokens(model=base_model) or 4096
) # assume min context window is 4k tokens
user_max_tokens = kwargs.get("max_tokens")
## Scenario 1: User limit + prompt > model limit
messages = None
if len(args) > 1:
messages = args[1]
elif kwargs.get("messages", None):
messages = kwargs["messages"]
input_tokens = token_counter(model=base_model, messages=messages)
input_tokens += max(
0.1 * input_tokens, 10
) # give at least a 10 token buffer. token counting can be imprecise.
if input_tokens > max_output_tokens:
pass # allow call to fail normally
elif user_max_tokens + input_tokens > max_output_tokens:
user_max_tokens = max_output_tokens - input_tokens
print_verbose(f"user_max_tokens: {user_max_tokens}")
kwargs["max_tokens"] = int(
round(user_max_tokens)
) # make sure max tokens is always an int
user_max_tokens = kwargs.get("max_tokens")
modified_max_tokens = get_modified_max_tokens(
model=model,
base_model=base_model,
messages=messages,
user_max_tokens=user_max_tokens,
)
kwargs["max_tokens"] = modified_max_tokens
except Exception as e:
print_verbose(f"Error while checking max token limit: {str(e)}")
# MODEL CALL
@@ -4352,7 +4344,7 @@ def get_utc_datetime():
return datetime.utcnow() # type: ignore
def get_max_tokens(model: str):
def get_max_tokens(model: str) -> Optional[int]:
"""
Get the maximum number of output tokens allowed for a given model.
@@ -4406,7 +4398,8 @@ def get_max_tokens(model: str):
return litellm.model_cost[model]["max_tokens"]
else:
raise Exception()
except:
return None
except Exception:
raise Exception(
f"Model {model} isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
)
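
Taken together, the caller-facing effect of the client() decorator change is roughly the following. A minimal sketch, not part of the diff, assuming an OpenAI key is configured and that gpt-3.5-turbo's 4096-token output limit from the test above still applies.

import litellm

# Opt in: the client decorator only rewrites max_tokens when modify_params is set.
litellm.modify_params = True

# The requested 5000 output tokens exceed the model's output limit, so max_tokens
# is lowered via get_modified_max_tokens before the request is sent, instead of
# the provider rejecting the call.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Summarize the plot of Hamlet."}],
    max_tokens=5000,
)
print(response.choices[0].message.content)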