fix(token_counter.py): New `get_modified_max_tokens` helper func

Fixes https://github.com/BerriAI/litellm/issues/4439
Krrish Dholakia 2024-06-27 15:38:09 -07:00
parent 0c5014c323
commit d421486a45
6 changed files with 165 additions and 23 deletions

litellm/utils.py

@@ -54,6 +54,7 @@ from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_s
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_logging,
 )
+from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.types.utils import (
     CallTypes,
@@ -813,7 +814,7 @@ def client(original_function):
                 kwargs.get("max_tokens", None) is not None
                 and model is not None
                 and litellm.modify_params
-                == True  # user is okay with params being modified
+                is True  # user is okay with params being modified
                 and (
                     call_type == CallTypes.acompletion.value
                     or call_type == CallTypes.completion.value
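The `== True` to `is True` change is behavioral as well as stylistic: identity comparison only matches the bool `True`, so a truthy non-bool assigned to `litellm.modify_params` no longer enables this branch. A quick self-contained illustration (the variable here is a stand-in, not litellm code):

    modify_params = 1  # a truthy non-bool, e.g. set by accident

    print(modify_params == True)  # True  -> the old check would fire
    print(modify_params is True)  # False -> the new check requires the actual bool True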
@@ -823,28 +824,19 @@ def client(original_function):
                     base_model = model
                     if kwargs.get("hf_model_name", None) is not None:
                         base_model = f"huggingface/{kwargs.get('hf_model_name')}"
-                    max_output_tokens = (
-                        get_max_tokens(model=base_model) or 4096
-                    )  # assume min context window is 4k tokens
-                    user_max_tokens = kwargs.get("max_tokens")
-                    ## Scenario 1: User limit + prompt > model limit
                     messages = None
                     if len(args) > 1:
                         messages = args[1]
                     elif kwargs.get("messages", None):
                         messages = kwargs["messages"]
-                    input_tokens = token_counter(model=base_model, messages=messages)
-                    input_tokens += max(
-                        0.1 * input_tokens, 10
-                    )  # give at least a 10 token buffer. token counting can be imprecise.
-                    if input_tokens > max_output_tokens:
-                        pass  # allow call to fail normally
-                    elif user_max_tokens + input_tokens > max_output_tokens:
-                        user_max_tokens = max_output_tokens - input_tokens
-                    print_verbose(f"user_max_tokens: {user_max_tokens}")
-                    kwargs["max_tokens"] = int(
-                        round(user_max_tokens)
-                    )  # make sure max tokens is always an int
+                    user_max_tokens = kwargs.get("max_tokens")
+                    modified_max_tokens = get_modified_max_tokens(
+                        model=model,
+                        base_model=base_model,
+                        messages=messages,
+                        user_max_tokens=user_max_tokens,
+                    )
+                    kwargs["max_tokens"] = modified_max_tokens
                 except Exception as e:
                     print_verbose(f"Error while checking max token limit: {str(e)}")
             # MODEL CALL
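`token_counter.py` itself is among the six changed files but isn't shown in this view. A minimal sketch of what `get_modified_max_tokens` plausibly does, assuming it simply relocates the inline logic deleted above; the signature is taken from the call site, everything else is an assumption:

    from typing import List, Optional

    from litellm import get_max_tokens, token_counter


    def get_modified_max_tokens(
        model: str,
        base_model: str,
        messages: Optional[List[dict]],
        user_max_tokens: Optional[int],
    ) -> Optional[int]:
        """Shrink user_max_tokens so prompt + completion fit the model's window (sketch)."""
        try:
            if user_max_tokens is None:
                return None
            # Assume a minimum context window of 4k tokens for unmapped models.
            max_output_tokens = get_max_tokens(model=base_model) or 4096
            input_tokens = token_counter(model=base_model, messages=messages)
            # Token counting can be imprecise; keep at least a 10-token buffer.
            input_tokens += int(max(0.1 * input_tokens, 10))
            if input_tokens > max_output_tokens:
                return user_max_tokens  # prompt alone exceeds the window; let the call fail normally
            if user_max_tokens + input_tokens > max_output_tokens:
                user_max_tokens = max_output_tokens - input_tokens
            return int(round(user_max_tokens))  # max_tokens must always be an int
        except Exception:
            return user_max_tokens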
@@ -4352,7 +4344,7 @@ def get_utc_datetime():
     return datetime.utcnow()  # type: ignore
 
 
-def get_max_tokens(model: str):
+def get_max_tokens(model: str) -> Optional[int]:
     """
     Get the maximum number of output tokens allowed for a given model.
@@ -4406,7 +4398,8 @@ def get_max_tokens(model: str):
             return litellm.model_cost[model]["max_tokens"]
         else:
             raise Exception()
-    except:
+        return None
+    except Exception:
         raise Exception(
             f"Model {model} isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
         )
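End to end, the behavior users opt into with `litellm.modify_params` looks like this. A usage sketch (the model choice and the oversized limit are illustrative, and a provider key such as OPENAI_API_KEY is assumed to be set):

    import litellm

    litellm.modify_params = True  # allow litellm to shrink max_tokens on our behalf

    # max_tokens is deliberately larger than the model's output window; with
    # modify_params enabled, the client wrapper trims it instead of erroring out.
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Summarize the plot of Hamlet."}],
        max_tokens=100_000,
    )
    print(response.choices[0].message.content)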