Mirror of https://github.com/BerriAI/litellm.git
fix(token_counter.py): New `get_modified_max_tokens` helper func
Fixes https://github.com/BerriAI/litellm/issues/4439
This commit is contained in:
parent 0c5014c323
commit d421486a45
6 changed files with 165 additions and 23 deletions
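For context, the code path touched in utils.py only runs when the caller has opted in to letting litellm rewrite request parameters. A minimal usage sketch (the model name and token count below are illustrative, not taken from the commit):

```python
import litellm

# Opt in to letting litellm adjust request params such as max_tokens.
litellm.modify_params = True

# If the requested max_tokens plus the prompt would exceed the model's limit,
# the client() wrapper trims max_tokens before the provider is called.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Summarize the plot of Hamlet."}],
    max_tokens=100000,  # deliberately larger than any real output window
)
print(response.choices[0].message.content)
```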
@@ -54,6 +54,7 @@ from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_logging,
 )
+from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.types.utils import (
     CallTypes,
@@ -813,7 +814,7 @@ def client(original_function):
                 kwargs.get("max_tokens", None) is not None
                 and model is not None
                 and litellm.modify_params
-                == True  # user is okay with params being modified
+                is True  # user is okay with params being modified
                 and (
                     call_type == CallTypes.acompletion.value
                     or call_type == CallTypes.completion.value
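The `== True` → `is True` switch tightens the guard: `==` treats any truthy value (for example `1`) as equal to `True`, while `is` only matches the boolean singleton. A minimal illustration:

```python
flag = 1
print(flag == True)  # True  -> equality: 1 compares equal to True
print(flag is True)  # False -> identity: 1 is not the bool singleton True

flag = True
print(flag is True)  # True

# So `litellm.modify_params is True` only enables the max_tokens adjustment
# when the flag is literally the boolean True, not merely a truthy value.
```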
@@ -823,28 +824,19 @@ def client(original_function):
                     base_model = model
                     if kwargs.get("hf_model_name", None) is not None:
                         base_model = f"huggingface/{kwargs.get('hf_model_name')}"
-                    max_output_tokens = (
-                        get_max_tokens(model=base_model) or 4096
-                    )  # assume min context window is 4k tokens
-                    user_max_tokens = kwargs.get("max_tokens")
-                    ## Scenario 1: User limit + prompt > model limit
                     messages = None
                     if len(args) > 1:
                         messages = args[1]
                     elif kwargs.get("messages", None):
                         messages = kwargs["messages"]
-                    input_tokens = token_counter(model=base_model, messages=messages)
-                    input_tokens += max(
-                        0.1 * input_tokens, 10
-                    )  # give at least a 10 token buffer. token counting can be imprecise.
-                    if input_tokens > max_output_tokens:
-                        pass  # allow call to fail normally
-                    elif user_max_tokens + input_tokens > max_output_tokens:
-                        user_max_tokens = max_output_tokens - input_tokens
-                    print_verbose(f"user_max_tokens: {user_max_tokens}")
-                    kwargs["max_tokens"] = int(
-                        round(user_max_tokens)
-                    )  # make sure max tokens is always an int
+                    user_max_tokens = kwargs.get("max_tokens")
+                    modified_max_tokens = get_modified_max_tokens(
+                        model=model,
+                        base_model=base_model,
+                        messages=messages,
+                        user_max_tokens=user_max_tokens,
+                    )
+                    kwargs["max_tokens"] = modified_max_tokens
                 except Exception as e:
                     print_verbose(f"Error while checking max token limit: {str(e)}")
             # MODEL CALL
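The new helper lives in litellm/litellm_core_utils/token_counter.py (one of the six changed files; its hunk isn't shown here). Judging from the inline logic it replaces, a rough sketch of what it is expected to do — the body and return behavior below are an assumption for illustration, not the committed implementation:

```python
from typing import List, Optional

from litellm import get_max_tokens, token_counter


def get_modified_max_tokens(
    model: str,
    base_model: str,
    messages: Optional[List[dict]],
    user_max_tokens: Optional[int],
) -> Optional[int]:
    """Cap user_max_tokens so prompt + completion fits the model's limit.

    Sketch only: mirrors the inline logic removed from client(); the real
    helper may differ in edge-case handling and return type.
    """
    try:
        if user_max_tokens is None:
            return None
        # Assume a 4k window when the model isn't mapped (old inline fallback).
        max_output_tokens = get_max_tokens(model=base_model) or 4096
        input_tokens = token_counter(model=base_model, messages=messages)
        # Token counting can be imprecise; keep at least a 10-token buffer.
        input_tokens += max(0.1 * input_tokens, 10)
        if input_tokens > max_output_tokens:
            # Prompt alone exceeds the limit; leave max_tokens untouched and
            # let the call fail normally.
            return user_max_tokens
        if user_max_tokens + input_tokens > max_output_tokens:
            user_max_tokens = max_output_tokens - input_tokens
        return int(round(user_max_tokens))
    except Exception:
        # On any error, hand back the original value untouched.
        return user_max_tokens
```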
@@ -4352,7 +4344,7 @@ def get_utc_datetime():
     return datetime.utcnow()  # type: ignore


-def get_max_tokens(model: str):
+def get_max_tokens(model: str) -> Optional[int]:
     """
     Get the maximum number of output tokens allowed for a given model.

@@ -4406,7 +4398,8 @@ def get_max_tokens(model: str):
                 return litellm.model_cost[model]["max_tokens"]
             else:
                 raise Exception()
-        except:
-            return None
+        except Exception:
+            raise Exception(
+                f"Model {model} isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
+            )
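Note the behavior change at the end of get_max_tokens: an unmapped model now raises instead of silently returning None, so callers that relied on a `get_max_tokens(...) or 4096` style fallback need an explicit try/except. A hedged usage sketch (the model names are illustrative):

```python
from litellm import get_max_tokens

# Mapped model: returns the max output tokens from the model cost map.
print(get_max_tokens("gpt-3.5-turbo"))

# Unmapped model: previously returned None, now raises with a pointer to
# model_prices_and_context_window.json.
try:
    limit = get_max_tokens("my-org/custom-model")
except Exception:
    limit = 4096  # conservative fallback, mirroring the old inline default
print(limit)
```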