fix(utils.py): support max token adjustment for sagemaker

This commit is contained in:
Krrish Dholakia 2024-01-31 19:09:54 -08:00
parent 93a52a2d35
commit a32639fa79
3 changed files with 41 additions and 10 deletions

View file

@ -2199,22 +2199,28 @@ def client(original_function):
)
):
try:
max_output_tokens = get_max_tokens(model=model)
base_model = model
if kwargs.get("hf_model_name", None) is not None:
base_model = f"huggingface/{kwargs.get('hf_model_name')}"
max_output_tokens = (
get_max_tokens(model=base_model) or 4096
) # assume min context window is 4k tokens
user_max_tokens = kwargs.get("max_tokens")
## Scenario 1: User limit > model limit
if user_max_tokens > max_output_tokens:
user_max_tokens = max_output_tokens
## Scenario 2: User limit + prompt > model limit
## Scenario 1: User limit + prompt > model limit
messages = None
if len(args) > 1:
messages = args[1]
elif kwargs.get("messages", None):
messages = kwargs["messages"]
input_tokens = token_counter(model=model, messages=messages)
input_tokens = token_counter(model=base_model, messages=messages)
input_tokens += max(
0.1 * input_tokens, 10
) # give at least a 10 token buffer. token counting can be imprecise.
if input_tokens > max_output_tokens:
pass # allow call to fail normally
elif user_max_tokens + input_tokens > max_output_tokens:
user_max_tokens = max_output_tokens - input_tokens
kwargs["max_tokens"] = user_max_tokens
except Exception as e:
print_verbose(f"Error while checking max token limit: {str(e)}")
# MODEL CALL
@ -4553,7 +4559,6 @@ def get_max_tokens(model: str):
def _get_max_position_embeddings(model_name):
# Construct the URL for the config.json file
config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
try:
# Make the HTTP request to get the raw JSON file
response = requests.get(config_url)
@ -4561,10 +4566,8 @@ def get_max_tokens(model: str):
# Parse the JSON response
config_json = response.json()
# Extract and return the max_position_embeddings
max_position_embeddings = config_json.get("max_position_embeddings")
if max_position_embeddings is not None:
return max_position_embeddings
else: