Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

fix(utils.py): support get_max_tokens() call with same model_name as completion

Closes https://github.com/BerriAI/litellm/issues/3921

parent: b8df5d1a01
commit: 7523f803d2

2 changed files with 19 additions and 0 deletions
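For context on the fix: get_max_tokens() previously raised for provider-prefixed model names that litellm.completion() accepts. A minimal sketch of the now-working call, using the model name from the new test below (illustrative, not part of the diff):

from litellm.utils import get_max_tokens

# Before this commit, a provider-prefixed name such as the one below raised
# an Exception even though litellm.completion() accepts the same string.
max_tokens = get_max_tokens("bedrock/anthropic.claude-3-haiku-20240307-v1:0")
print(max_tokens)  # an int, read from litellm's model cost map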
Test file (imports from litellm.utils) — adds get_max_tokens to the imports and a new unit test:

@@ -22,6 +22,7 @@ from litellm.utils import (
     token_counter,
     create_pretrained_tokenizer,
     create_tokenizer,
+    get_max_tokens,
 )

 # Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils'
@@ -372,3 +373,16 @@ def test_supports_function_calling():
         assert litellm.supports_function_calling(model="claude-2") == False
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+def test_get_max_token_unit_test():
+    """
+    More complete testing in `test_completion_cost.py`
+    """
+    model = "bedrock/anthropic.claude-3-haiku-20240307-v1:0"
+
+    max_tokens = get_max_tokens(
+        model
+    )  # Returns a number instead of throwing an Exception
+
+    assert isinstance(max_tokens, int)
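The test relies on litellm's bundled cost map having an entry for that Bedrock model. A quick way to inspect which keys the fix can fall back on (litellm.model_cost is litellm's public cost-map dict; the entry's exact contents may vary by release):

import litellm

entry = litellm.model_cost.get("bedrock/anthropic.claude-3-haiku-20240307-v1:0", {})
# get_max_tokens() prefers "max_output_tokens" and falls back to "max_tokens"
# (key names per the utils.py hunk below).
print(entry.get("max_output_tokens"), entry.get("max_tokens"))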
litellm/utils.py — get_max_tokens() now falls back to the model cost map before raising:

@@ -7065,6 +7065,11 @@ def get_max_tokens(model: str):
         if custom_llm_provider == "huggingface":
             max_tokens = _get_max_position_embeddings(model_name=model)
             return max_tokens
+        if model in litellm.model_cost:  # check if extracted model is in model_list
+            if "max_output_tokens" in litellm.model_cost[model]:
+                return litellm.model_cost[model]["max_output_tokens"]
+            elif "max_tokens" in litellm.model_cost[model]:
+                return litellm.model_cost[model]["max_tokens"]
         else:
             raise Exception()
     except:
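Pulled out of the surrounding function, the lookup order the patch implements is: resolve huggingface models from their config, otherwise read the cost-map entry, preferring "max_output_tokens" over "max_tokens", and raise when the model is unknown. A simplified standalone sketch, not the real get_max_tokens() (the actual function also handles the huggingface branch, and appears to fall through silently when an entry has neither key, where this sketch raises):

def resolve_max_tokens(model: str, model_cost: dict) -> int:
    """Sketch of the fallback order get_max_tokens() uses after this patch."""
    if model not in model_cost:
        raise Exception(f"no cost-map entry for {model}")
    entry = model_cost[model]
    # Prefer the explicit output-token cap; fall back to the context window.
    if "max_output_tokens" in entry:
        return entry["max_output_tokens"]
    if "max_tokens" in entry:
        return entry["max_tokens"]
    raise Exception(f"no max token information for {model}")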