forked from phoenix/litellm-mirror
feat(get_max_tokens): get max tokens for huggingface hub models
This commit is contained in:
parent
1a705bfbcb
commit
f84db3ce14
2 changed files with 42 additions and 3 deletions
|
@ -7,8 +7,6 @@ sys.path.insert(
|
||||||
import time
|
import time
|
||||||
from litellm import get_max_tokens, model_cost, open_ai_chat_completion_models
|
from litellm import get_max_tokens, model_cost, open_ai_chat_completion_models
|
||||||
|
|
||||||
print(get_max_tokens("gpt-3.5-turbo"))
|
|
||||||
|
|
||||||
def test_get_gpt3_tokens():
|
def test_get_gpt3_tokens():
|
||||||
max_tokens = get_max_tokens("gpt-3.5-turbo")
|
max_tokens = get_max_tokens("gpt-3.5-turbo")
|
||||||
results = max_tokens['max_tokens']
|
results = max_tokens['max_tokens']
|
||||||
|
@ -21,3 +19,10 @@ def test_get_palm_tokens():
|
||||||
results = max_tokens['max_tokens']
|
results = max_tokens['max_tokens']
|
||||||
print(results)
|
print(results)
|
||||||
# test_get_palm_tokens()
|
# test_get_palm_tokens()
|
||||||
|
|
||||||
|
def test_zephyr_hf_tokens():
|
||||||
|
max_tokens = get_max_tokens("huggingface/HuggingFaceH4/zephyr-7b-beta")
|
||||||
|
results = max_tokens["max_tokens"]
|
||||||
|
print(results)
|
||||||
|
|
||||||
|
test_zephyr_hf_tokens()
|
|
@ -2354,8 +2354,42 @@ def get_max_tokens(model: str):
|
||||||
"mode": "chat"
|
"mode": "chat"
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
def _get_max_position_embeddings(model_name):
|
||||||
|
# Construct the URL for the config.json file
|
||||||
|
config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Make the HTTP request to get the raw JSON file
|
||||||
|
response = requests.get(config_url)
|
||||||
|
response.raise_for_status() # Raise an exception for bad responses (4xx or 5xx)
|
||||||
|
|
||||||
|
# Parse the JSON response
|
||||||
|
config_json = response.json()
|
||||||
|
|
||||||
|
# Extract and return the max_position_embeddings
|
||||||
|
max_position_embeddings = config_json.get("max_position_embeddings")
|
||||||
|
|
||||||
|
if max_position_embeddings is not None:
|
||||||
|
return max_position_embeddings
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
return None
|
||||||
try:
|
try:
|
||||||
return litellm.model_cost[model]
|
if model in litellm.model_cost:
|
||||||
|
return litellm.model_cost[model]
|
||||||
|
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
|
||||||
|
if custom_llm_provider == "huggingface":
|
||||||
|
max_tokens = _get_max_position_embeddings(model_name=model)
|
||||||
|
return {
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"input_cost_per_token": 0,
|
||||||
|
"output_cost_per_token": 0,
|
||||||
|
"litellm_provider": "huggingface",
|
||||||
|
"mode": "chat"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise Exception()
|
||||||
except:
|
except:
|
||||||
raise Exception("This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json")
|
raise Exception("This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json")
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue