Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00
fix(utils.py): support checking if user defined max tokens exceeds model limit
parent e71d3f8df4
commit a0daac212d
1 changed file with 33 additions and 2 deletions
@@ -2186,6 +2186,34 @@ def client(original_function):
                         )
                     else:
                         return cached_result
+
+            # CHECK MAX TOKENS
+            if (
+                kwargs.get("max_tokens", None) is not None
+                and model is not None
+                and litellm.drop_params
+                == True  # user is okay with params being modified
+                and (
+                    call_type == CallTypes.acompletion.value
+                    or call_type == CallTypes.completion.value
+                )
+            ):
+                try:
+                    max_output_tokens = get_max_tokens(model=model)
+                    user_max_tokens = kwargs.get("max_tokens")
+                    ## Scenario 1: User limit > model limit
+                    if user_max_tokens > max_output_tokens:
+                        user_max_tokens = max_output_tokens
+                    ## Scenario 2: User limit + prompt > model limit
+                    input_tokens = token_counter(
+                        model=model, messages=kwargs.get("messages")
+                    )
+                    if input_tokens > max_output_tokens:
+                        pass  # allow call to fail normally
+                    elif user_max_tokens + input_tokens > max_output_tokens:
+                        user_max_tokens = max_output_tokens - input_tokens
+                except Exception as e:
+                    print_verbose(f"Error while checking max token limit: {str(e)}")
             # MODEL CALL
             result = original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
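In plain terms, the new block only runs when the caller has opted in to parameter modification (litellm.drop_params == True) and the call is a completion or acompletion. A minimal standalone sketch of the capping arithmetic follows; the helper name cap_max_tokens and the token counts are hypothetical, while the real wrapper uses litellm's get_max_tokens and token_counter as shown in the hunk above:

def cap_max_tokens(user_max_tokens: int, input_tokens: int, max_output_tokens: int) -> int:
    # Scenario 1: the user's limit alone exceeds the model's output limit.
    if user_max_tokens > max_output_tokens:
        user_max_tokens = max_output_tokens
    # Scenario 2: the prompt alone already exceeds the limit; leave the
    # request as-is so the provider call fails normally.
    if input_tokens > max_output_tokens:
        return user_max_tokens
    # Scenario 2b: prompt plus requested completion exceed the limit, so
    # shrink the completion budget to the remaining room.
    if user_max_tokens + input_tokens > max_output_tokens:
        user_max_tokens = max_output_tokens - input_tokens
    return user_max_tokens

# e.g. a 4096-token output limit, a 3000-token prompt, and max_tokens=2000
# leaves 4096 - 3000 = 1096 tokens for the completion.
assert cap_max_tokens(2000, 3000, 4096) == 1096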
@@ -4503,7 +4531,7 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):

 def get_max_tokens(model: str):
     """
-    Get the maximum number of tokens allowed for a given model.
+    Get the maximum number of output tokens allowed for a given model.

     Parameters:
     model (str): The name of the model.
@@ -4543,7 +4571,10 @@ def get_max_tokens(model: str):

     try:
         if model in litellm.model_cost:
-            return litellm.model_cost[model]["max_tokens"]
+            if "max_output_tokens" in litellm.model_cost[model]:
+                return litellm.model_cost[model]["max_output_tokens"]
+            elif "max_tokens" in litellm.model_cost[model]:
+                return litellm.model_cost[model]["max_tokens"]
         model, custom_llm_provider, _, _ = get_llm_provider(model=model)
         if custom_llm_provider == "huggingface":
             max_tokens = _get_max_position_embeddings(model_name=model)
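With the lookup change above, get_max_tokens prefers a model's dedicated "max_output_tokens" entry in litellm.model_cost and falls back to the generic "max_tokens" field. A rough usage sketch, assuming get_max_tokens is importable from litellm.utils where the diff defines it; the model name and cost-map values below are made up for illustration:

import litellm
from litellm.utils import get_max_tokens

# Hypothetical cost-map entry: 8192-token context window, but only
# 4096 tokens of output allowed.
litellm.model_cost["my-hypothetical-model"] = {
    "max_tokens": 8192,
    "max_output_tokens": 4096,
}

print(get_max_tokens("my-hypothetical-model"))  # 4096 (the output limit wins)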