Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix(utils.py): support max token adjustment for sagemaker
parent 93a52a2d35
commit a32639fa79
3 changed files with 41 additions and 10 deletions
@@ -2199,22 +2199,28 @@ def client(original_function):
                 )
             ):
                 try:
-                    max_output_tokens = get_max_tokens(model=model)
+                    base_model = model
+                    if kwargs.get("hf_model_name", None) is not None:
+                        base_model = f"huggingface/{kwargs.get('hf_model_name')}"
+                    max_output_tokens = (
+                        get_max_tokens(model=base_model) or 4096
+                    )  # assume min context window is 4k tokens
                     user_max_tokens = kwargs.get("max_tokens")
-                    ## Scenario 1: User limit > model limit
-                    if user_max_tokens > max_output_tokens:
-                        user_max_tokens = max_output_tokens
-                    ## Scenario 2: User limit + prompt > model limit
+                    ## Scenario 1: User limit + prompt > model limit
                     messages = None
                     if len(args) > 1:
                         messages = args[1]
                     elif kwargs.get("messages", None):
                         messages = kwargs["messages"]
-                    input_tokens = token_counter(model=model, messages=messages)
+                    input_tokens = token_counter(model=base_model, messages=messages)
+                    input_tokens += max(
+                        0.1 * input_tokens, 10
+                    )  # give at least a 10 token buffer. token counting can be imprecise.
+                    if input_tokens > max_output_tokens:
+                        pass  # allow call to fail normally
+                    elif user_max_tokens + input_tokens > max_output_tokens:
+                        user_max_tokens = max_output_tokens - input_tokens
                     kwargs["max_tokens"] = user_max_tokens
                 except Exception as e:
                     print_verbose(f"Error while checking max token limit: {str(e)}")
             # MODEL CALL
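The clamping rule the hunk above implements can be summarized in a standalone function. The sketch below is illustrative only: `adjust_max_tokens` is a hypothetical name, and `model_max_tokens` / `prompt_tokens` stand in for the values litellm derives through `get_max_tokens()` and `token_counter()`. The 4096 fallback and the max(10%, 10 tokens) buffer mirror the constants in the diff.

```python
from typing import Optional

def adjust_max_tokens(
    user_max_tokens: int,
    prompt_tokens: int,
    model_max_tokens: Optional[int],
) -> int:
    # Hypothetical distillation of the hunk above; not litellm's actual API.
    limit = model_max_tokens or 4096  # assume min context window is 4k tokens
    # Token counting can be imprecise: pad the prompt count by at least
    # 10% of its size or 10 tokens, whichever is larger.
    padded_prompt = prompt_tokens + max(int(0.1 * prompt_tokens), 10)
    if padded_prompt > limit:
        return user_max_tokens  # prompt alone overflows; let the call fail normally
    if user_max_tokens + padded_prompt > limit:
        return limit - padded_prompt  # shrink the completion budget to fit
    return user_max_tokens  # user's request already fits

# Example: 4k-context model, 3,000-token prompt, user asks for 2,000 tokens.
# Padded prompt = 3,300, and 2,000 + 3,300 > 4,096, so the budget is
# clamped to 4,096 - 3,300 = 796.
assert adjust_max_tokens(2000, 3000, 4096) == 796
```

This is also why the hunk maps `hf_model_name` to `base_model = f"huggingface/..."`: for SageMaker endpoints serving Hugging Face models, the context-window lookup and token count have to use the underlying HF model name rather than the endpoint's own name.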
@@ -4553,7 +4559,6 @@ def get_max_tokens(model: str):
     def _get_max_position_embeddings(model_name):
         # Construct the URL for the config.json file
         config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"

         try:
             # Make the HTTP request to get the raw JSON file
             response = requests.get(config_url)
@@ -4561,10 +4566,8 @@ def get_max_tokens(model: str):

             # Parse the JSON response
             config_json = response.json()

             # Extract and return the max_position_embeddings
             max_position_embeddings = config_json.get("max_position_embeddings")

             if max_position_embeddings is not None:
                 return max_position_embeddings
             else:
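The last two hunks only preserve their unchanged context lines, but they show the shape of the `_get_max_position_embeddings` helper inside `get_max_tokens`: fetch a model's config.json from the Hugging Face Hub and read `max_position_embeddings` from it. Below is a self-contained sketch of that lookup, with an assumed return-`None`-on-failure policy; the timeout and exception handling here are illustrative, not necessarily what utils.py does.

```python
from typing import Optional

import requests

def get_max_position_embeddings(model_name: str) -> Optional[int]:
    # Construct the URL for the raw config.json file on the Hugging Face Hub.
    config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
    try:
        response = requests.get(config_url, timeout=5)
        response.raise_for_status()
        # Not every architecture exposes this key (e.g. GPT-2 uses
        # "n_positions" instead), so .get() may legitimately return None.
        return response.json().get("max_position_embeddings")
    except (requests.RequestException, ValueError):
        return None  # network error or non-JSON response

print(get_max_position_embeddings("bert-base-uncased"))  # 512 per BERT's config
```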