fix(token_counter.py): New `get_modified_max_tokens` helper func

Fixes https://github.com/BerriAI/litellm/issues/4439

parent 0c5014c323
commit d421486a45

6 changed files with 165 additions and 23 deletions
@@ -738,6 +738,7 @@ openai_image_generation_models = ["dall-e-2", "dall-e-3"]
 from .timeout import timeout
 from .cost_calculator import completion_cost
 from litellm.litellm_core_utils.litellm_logging import Logging
+from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
 from .utils import (
     client,
     exception_type,
@@ -1,5 +1,5 @@
 # What is this?
-## Helper utilities for the model response objects
+## Helper utilities
 
 
 def map_finish_reason(
litellm/litellm_core_utils/token_counter.py (new file, 83 lines)
@@ -0,0 +1,83 @@
+# What is this?
+## Helper utilities for token counting
+from typing import Optional
+
+import litellm
+from litellm import verbose_logger
+
+
+def get_modified_max_tokens(
+    model: str,
+    base_model: str,
+    messages: Optional[list],
+    user_max_tokens: Optional[int],
+    buffer_perc: Optional[float],
+    buffer_num: Optional[float],
+) -> Optional[int]:
+    """
+    Params:
+
+    Returns the user's max output tokens, adjusted for:
+    - the size of input - for models where input + output can't exceed X
+    - model max output tokens - for models where there is a separate output token limit
+    """
+    try:
+        if user_max_tokens is None:
+            return None
+
+        ## MODEL INFO
+        _model_info = litellm.get_model_info(model=model)
+
+        max_output_tokens = litellm.get_max_tokens(
+            model=base_model
+        )  # assume min context window is 4k tokens
+
+        ## UNKNOWN MAX OUTPUT TOKENS - return user defined amount
+        if max_output_tokens is None:
+            return user_max_tokens
+
+        input_tokens = litellm.token_counter(model=base_model, messages=messages)
+
+        # token buffer
+        if buffer_perc is None:
+            buffer_perc = 0.1
+        if buffer_num is None:
+            buffer_num = 10
+        token_buffer = max(
+            buffer_perc * input_tokens, buffer_num
+        )  # give at least a 10 token buffer. token counting can be imprecise.
+
+        input_tokens += int(token_buffer)
+        verbose_logger.debug(
+            f"max_output_tokens: {max_output_tokens}, user_max_tokens: {user_max_tokens}"
+        )
+        ## CASE 1: model input + output can't exceed X - happens when max input = max output, e.g. gpt-3.5-turbo
+        if _model_info["max_input_tokens"] == max_output_tokens:
+            verbose_logger.debug(
+                f"input_tokens: {input_tokens}, max_output_tokens: {max_output_tokens}"
+            )
+            if input_tokens > max_output_tokens:
+                pass  # allow call to fail normally - don't set max_tokens to negative.
+            elif (
+                user_max_tokens + input_tokens > max_output_tokens
+            ):  # we can still modify to keep it positive but below the limit
+                verbose_logger.debug(
+                    f"MODIFYING MAX TOKENS - user_max_tokens={user_max_tokens}, input_tokens={input_tokens}, max_output_tokens={max_output_tokens}"
+                )
+                user_max_tokens = int(max_output_tokens - input_tokens)
+        ## CASE 2: user_max_tokens > model max output tokens
+        elif user_max_tokens > max_output_tokens:
+            user_max_tokens = max_output_tokens
+
+        verbose_logger.debug(
+            f"litellm.litellm_core_utils.token_counter.py::get_modified_max_tokens() - user_max_tokens: {user_max_tokens}"
+        )
+
+        return user_max_tokens
+    except Exception as e:
+        verbose_logger.error(
+            "litellm.litellm_core_utils.token_counter.py::get_modified_max_tokens() - Error while checking max token limit: {}\nmodel={}, base_model={}".format(
+                str(e), model, base_model
+            )
+        )
+        return user_max_tokens
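For orientation, here is a minimal sketch of how the new helper is meant to be called. The model names, message content, and requested token count are illustrative assumptions, not values taken from the commit:

    from litellm.litellm_core_utils.token_counter import get_modified_max_tokens

    # A prompt whose token count, plus the requested output, may exceed the model's limits.
    messages = [{"role": "user", "content": "Summarize the following report ..."}]

    adjusted_max_tokens = get_modified_max_tokens(
        model="gpt-3.5-turbo",       # name as passed by the caller (illustrative)
        base_model="gpt-3.5-turbo",  # resolved model used for token counting (illustrative)
        messages=messages,
        user_max_tokens=4096,        # the caller's requested max_tokens
        buffer_perc=None,            # None -> helper defaults to 0.1 (10% of input)
        buffer_num=None,             # None -> helper defaults to a 10-token floor
    )
    # The result is the requested value when it already fits, a smaller value when
    # input + output would exceed a shared context window, the model's output cap
    # when the request exceeds it, or None when user_max_tokens was None.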
@@ -1,4 +1,7 @@
 model_list:
+  - model_name: gemini-1.5-flash-gemini
+    litellm_params:
+      model: gemini/gemini-1.5-flash
   - model_name: gemini-1.5-flash-gemini
     litellm_params:
       model: gemini/gemini-1.5-flash
@@ -1,15 +1,25 @@
 #### What this tests ####
 # This tests litellm.token_counter() function
 
-import sys, os
+import os
+import sys
 import traceback
 
 import pytest
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import time
-from litellm import token_counter, create_pretrained_tokenizer, encode, decode
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from litellm import (
+    create_pretrained_tokenizer,
+    decode,
+    encode,
+    get_modified_max_tokens,
+    token_counter,
+)
 from litellm.tests.large_text import text
 
+
@@ -227,3 +237,55 @@ def test_openai_token_with_image_and_text():
 
     token_count = token_counter(model=model, messages=messages)
     print(token_count)
+
+
+@pytest.mark.parametrize(
+    "model, base_model, input_tokens, user_max_tokens, expected_value",
+    [
+        ("random-model", "random-model", 1024, 1024, 1024),
+        ("command", "command", 1000000, None, None),  # model max = 4096
+        ("command", "command", 4000, 256, 96),  # model max = 4096
+        ("command", "command", 4000, 10, 10),  # model max = 4096
+        ("gpt-3.5-turbo", "gpt-3.5-turbo", 4000, 5000, 4096),  # model max output = 4096
+    ],
+)
+def test_get_modified_max_tokens(
+    model, base_model, input_tokens, user_max_tokens, expected_value
+):
+    """
+    - Test when max_output is not known => expect user_max_tokens
+    - Test when max_output == max_input,
+        - input > max_output, no max_tokens => expect None
+        - input + max_tokens > max_output => expect remainder
+        - input + max_tokens < max_output => expect max_tokens
+    - Test when max_tokens > max_output => expect max_output
+    """
+    args = locals()
+    import litellm
+
+    litellm.token_counter = MagicMock()
+
+    def _mock_token_counter(*args, **kwargs):
+        return input_tokens
+
+    litellm.token_counter.side_effect = _mock_token_counter
+    print(f"_mock_token_counter: {_mock_token_counter()}")
+    messages = [{"role": "user", "content": "Hello world!"}]
+
+    calculated_value = get_modified_max_tokens(
+        model=model,
+        base_model=base_model,
+        messages=messages,
+        user_max_tokens=user_max_tokens,
+        buffer_perc=0,
+        buffer_num=0,
+    )
+
+    if expected_value is None:
+        assert calculated_value is None
+    else:
+        assert (
+            calculated_value == expected_value
+        ), "Got={}, Expected={}, Params={}".format(
+            calculated_value, expected_value, args
+        )
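The expected values follow directly from the helper's two cases. Assuming, as the test's own inline comments indicate, a 4096-token limit for "command" (shared input/output window) and a separate 4096-token output cap for "gpt-3.5-turbo", and with both buffers zeroed out, the arithmetic works out as below; this is just a sanity check, not part of the commit:

    # Worked expectations for the parametrized cases above (buffer_perc=0, buffer_num=0):
    assert 4096 - 4000 == 96        # ("command", 4000, 256): shrink to the remaining window
    assert 4000 + 10 <= 4096        # ("command", 4000, 10): request already fits, keep 10
    assert min(5000, 4096) == 4096  # ("gpt-3.5-turbo", 4000, 5000): clamp to the output cap
    # ("random-model", ...): no known limits, so the requested 1024 passes through unchanged.
    # ("command", 1000000, None): user_max_tokens is None, so the helper returns None.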
@@ -54,6 +54,7 @@ from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_s
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_logging,
 )
+from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.types.utils import (
     CallTypes,
@@ -813,7 +814,7 @@ def client(original_function):
                 kwargs.get("max_tokens", None) is not None
                 and model is not None
                 and litellm.modify_params
-                == True  # user is okay with params being modified
+                is True  # user is okay with params being modified
                 and (
                     call_type == CallTypes.acompletion.value
                     or call_type == CallTypes.completion.value
@@ -823,28 +824,19 @@ def client(original_function):
                     base_model = model
                     if kwargs.get("hf_model_name", None) is not None:
                         base_model = f"huggingface/{kwargs.get('hf_model_name')}"
-                    max_output_tokens = (
-                        get_max_tokens(model=base_model) or 4096
-                    )  # assume min context window is 4k tokens
-                    user_max_tokens = kwargs.get("max_tokens")
-                    ## Scenario 1: User limit + prompt > model limit
                     messages = None
                     if len(args) > 1:
                         messages = args[1]
                     elif kwargs.get("messages", None):
                         messages = kwargs["messages"]
-                    input_tokens = token_counter(model=base_model, messages=messages)
-                    input_tokens += max(
-                        0.1 * input_tokens, 10
-                    )  # give at least a 10 token buffer. token counting can be imprecise.
-                    if input_tokens > max_output_tokens:
-                        pass  # allow call to fail normally
-                    elif user_max_tokens + input_tokens > max_output_tokens:
-                        user_max_tokens = max_output_tokens - input_tokens
-                    print_verbose(f"user_max_tokens: {user_max_tokens}")
-                    kwargs["max_tokens"] = int(
-                        round(user_max_tokens)
-                    )  # make sure max tokens is always an int
+                    user_max_tokens = kwargs.get("max_tokens")
+                    modified_max_tokens = get_modified_max_tokens(
+                        model=model,
+                        base_model=base_model,
+                        messages=messages,
+                        user_max_tokens=user_max_tokens,
+                    )
+                    kwargs["max_tokens"] = modified_max_tokens
                 except Exception as e:
                     print_verbose(f"Error while checking max token limit: {str(e)}")
                 # MODEL CALL
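The wrapper change above replaces the old inline clamping with a single call to the helper, still gated on litellm.modify_params. A rough sketch of the opt-in flow from the caller's side; the model name, prompt, and requested max_tokens are illustrative assumptions, and no call is implied to succeed as written:

    import litellm

    # Without this flag the wrapper leaves max_tokens untouched and an oversized
    # request is allowed to fail at the provider.
    litellm.modify_params = True

    response = litellm.completion(
        model="gpt-3.5-turbo",  # illustrative model choice
        messages=[{"role": "user", "content": "a very long prompt ..."}],
        max_tokens=5000,  # larger than the model's output limit
    )
    # Before the provider call, the `client` wrapper swaps kwargs["max_tokens"]
    # for the value returned by get_modified_max_tokens().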
@@ -4352,7 +4344,7 @@ def get_utc_datetime():
     return datetime.utcnow()  # type: ignore
 
 
-def get_max_tokens(model: str):
+def get_max_tokens(model: str) -> Optional[int]:
     """
     Get the maximum number of output tokens allowed for a given model.
 
@@ -4406,7 +4398,8 @@ def get_max_tokens(model: str):
                 return litellm.model_cost[model]["max_tokens"]
             else:
                 raise Exception()
-    except:
+        return None
+    except Exception:
         raise Exception(
             f"Model {model} isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
         )
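The annotation change makes the possible None return explicit, and get_modified_max_tokens above already guards for it by keeping the user's value when no output limit is known. A small defensive sketch of that contract, with placeholder model names and a hypothetical wrapper, under the assumption that unmapped models may still raise rather than return None:

    from typing import Optional

    import litellm

    def known_output_limit(model: str) -> Optional[int]:
        # Hypothetical helper: treat "no limit found" and "model isn't mapped"
        # the same way, so callers only have to handle int-or-None.
        try:
            return litellm.get_max_tokens(model=model)
        except Exception:
            return None

    limit = known_output_limit("gpt-3.5-turbo")       # an integer from the cost map
    missing = known_output_limit("not-a-real-model")  # None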