mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
(v1.0+ breaking change) get_max_tokens -> return int
This commit is contained in:
parent
c162f8b4b0
commit
bd82559553
3 changed files with 63 additions and 9 deletions
|
@ -346,6 +346,7 @@ from .utils import (
|
|||
acreate,
|
||||
get_model_list,
|
||||
get_max_tokens,
|
||||
get_model_info,
|
||||
register_prompt_template,
|
||||
validate_environment,
|
||||
check_valid_key,
|
||||
|
|
|
@ -9,20 +9,21 @@ from litellm import get_max_tokens, model_cost, open_ai_chat_completion_models
|
|||
|
||||
def test_get_gpt3_tokens():
    """Check that get_max_tokens returns the plain integer limit for gpt-3.5-turbo."""
    max_tokens = get_max_tokens("gpt-3.5-turbo")
    # get_max_tokens returns the integer itself (v1.0+ breaking change), so the
    # value must not be subscripted with ['max_tokens'] any more.
    print(max_tokens)
    assert max_tokens == 4097
|
||||
test_get_gpt3_tokens()
|
||||
|
||||
def test_get_palm_tokens():
    """Check that get_max_tokens returns the plain integer limit for palm/chat-bison."""
    max_tokens = get_max_tokens("palm/chat-bison")
    # get_max_tokens returns the integer itself (v1.0+ breaking change), so the
    # value must not be subscripted with ['max_tokens'] any more.
    print(max_tokens)
    assert max_tokens == 4096
|
||||
test_get_palm_tokens()
|
||||
|
||||
def test_zephyr_hf_tokens():
    """Check that get_max_tokens returns the plain integer limit for a Hugging Face model.

    The model id used here is resolved through the "huggingface" provider
    prefix rather than a hard-coded provider-less name.
    """
    max_tokens = get_max_tokens("huggingface/HuggingFaceH4/zephyr-7b-beta")
    # get_max_tokens returns the integer itself (v1.0+ breaking change), so the
    # value must not be subscripted with ["max_tokens"] any more.
    print(max_tokens)
    assert max_tokens == 32768
|
||||
|
||||
test_zephyr_hf_tokens()
|
|
@ -2358,6 +2358,58 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
|
|||
return api_key
|
||||
|
||||
def get_max_tokens(model: str) -> Optional[int]:
    """
    Get the maximum number of tokens allowed for a given model.

    The model is first looked up in ``litellm.model_cost``. If it is not
    there, the provider is resolved with ``get_llm_provider``; for the
    "huggingface" provider the value of ``max_position_embeddings`` is read
    from the model's ``config.json`` on huggingface.co.

    Parameters:
        model (str): The name of the model.

    Returns:
        int: The maximum number of tokens allowed for the given model.
        For a Hugging Face model, ``None`` is returned when the config
        request fails or the config has no ``max_position_embeddings`` entry.

    Raises:
        Exception: If the model is not mapped yet.

    Example:
        >>> get_max_tokens("gpt-4")
        8192
    """
    not_mapped_message = "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"

    def _get_max_position_embeddings(model_name):
        """Return ``max_position_embeddings`` from the model's config.json, or None."""
        # Construct the URL for the config.json file
        config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"

        try:
            # Bound the request: without a timeout an unresponsive host would
            # block the caller indefinitely.
            response = requests.get(config_url, timeout=10)
            response.raise_for_status()  # Raise an exception for bad responses (4xx or 5xx)

            # dict.get already yields None when the key is absent.
            return response.json().get("max_position_embeddings")
        except requests.exceptions.RequestException:
            # Best-effort lookup: network/HTTP failures mean "unknown".
            return None

    try:
        if model in litellm.model_cost:
            return litellm.model_cost[model]["max_tokens"]
        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
        if custom_llm_provider == "huggingface":
            return _get_max_position_embeddings(model_name=model)
    except Exception as e:
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt / SystemExit are not re-labelled as "not mapped";
        # chain the cause to keep the original failure visible.
        raise Exception(not_mapped_message) from e

    # Known provider, but neither listed in model_cost nor resolvable via Hugging Face.
    raise Exception(not_mapped_message)
|
||||
|
||||
|
||||
def get_model_info(model: str):
|
||||
"""
|
||||
Get a dict for the maximum tokens (context window),
|
||||
input_cost_per_token, output_cost_per_token for a given model.
|
||||
|
@ -2377,7 +2429,7 @@ def get_max_tokens(model: str):
|
|||
Exception: If the model is not mapped yet.
|
||||
|
||||
Example:
|
||||
>>> get_max_tokens("gpt-4")
|
||||
>>> get_model_info("gpt-4")
|
||||
{
|
||||
"max_tokens": 8192,
|
||||
"input_cost_per_token": 0.00003,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue