feat(get_max_tokens): get max tokens for huggingface hub models

Krrish Dholakia 2023-11-15 15:25:40 -08:00
parent 1a705bfbcb
commit f84db3ce14
2 changed files with 42 additions and 3 deletions

litellm/utils.py

@@ -2354,8 +2354,42 @@ def get_max_tokens(model: str):
             "mode": "chat"
         }
     """
+    def _get_max_position_embeddings(model_name):
+        # Construct the URL for the config.json file
+        config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
+
+        try:
+            # Make the HTTP request to get the raw JSON file
+            response = requests.get(config_url)
+            response.raise_for_status()  # Raise an exception for bad responses (4xx or 5xx)
+
+            # Parse the JSON response
+            config_json = response.json()
+
+            # Extract and return the max_position_embeddings
+            max_position_embeddings = config_json.get("max_position_embeddings")
+
+            if max_position_embeddings is not None:
+                return max_position_embeddings
+            else:
+                return None
+        except requests.exceptions.RequestException as e:
+            return None
     try:
-        return litellm.model_cost[model]
+        if model in litellm.model_cost:
+            return litellm.model_cost[model]
+        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+        if custom_llm_provider == "huggingface":
+            max_tokens = _get_max_position_embeddings(model_name=model)
+            return {
+                "max_tokens": max_tokens,
+                "input_cost_per_token": 0,
+                "output_cost_per_token": 0,
+                "litellm_provider": "huggingface",
+                "mode": "chat"
+            }
+        else:
+            raise Exception()
     except:
         raise Exception("This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json")
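
For a quick sanity check of the new path, here is a minimal usage sketch. It assumes get_max_tokens is exported at the litellm package root (like the other utils) and uses an illustrative Hub model id; any repo whose config.json lacks max_position_embeddings will come back with "max_tokens" set to None, since the helper swallows request errors and missing keys.

import litellm

# A model id not present in model_prices_and_context_window.json, so
# get_max_tokens() falls through to get_llm_provider(); for a
# "huggingface/..." id it fetches
# https://huggingface.co/{model}/raw/main/config.json and reads the
# max_position_embeddings field.
info = litellm.get_max_tokens("huggingface/codellama/CodeLlama-7b-Instruct-hf")

print(info["max_tokens"])            # context window read from config.json
print(info["litellm_provider"])      # "huggingface"
print(info["input_cost_per_token"])  # 0 - the hub config carries no pricing

Models that resolve to any other provider (or fail the lookup entirely) still hit the bare except and raise the "model isn't mapped yet" exception, so the Hub fallback only changes behavior for huggingface models.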