fix(utils.py): support get_max_tokens() call with same model_name as completion

Closes https://github.com/BerriAI/litellm/issues/3921
Author: Krrish Dholakia
Date:   2024-05-31 21:37:25 -07:00
Parent: b8df5d1a01
Commit: 7523f803d2
2 changed files with 19 additions and 0 deletions

litellm/utils.py

@@ -7065,6 +7065,11 @@ def get_max_tokens(model: str):
         if custom_llm_provider == "huggingface":
             max_tokens = _get_max_position_embeddings(model_name=model)
             return max_tokens
+        if model in litellm.model_cost:  # check if extracted model is in model_list
+            if "max_output_tokens" in litellm.model_cost[model]:
+                return litellm.model_cost[model]["max_output_tokens"]
+            elif "max_tokens" in litellm.model_cost[model]:
+                return litellm.model_cost[model]["max_tokens"]
         else:
             raise Exception()
     except:
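
For context (not part of the commit): a minimal usage sketch of what this change enables, assuming the fix lets get_max_tokens() resolve the same model string that completion() accepts, provided the model appears in litellm.model_cost. The model name and printed value below are illustrative.

import litellm

# get_max_tokens() now accepts the same model string used for completion(),
# returning the model's "max_output_tokens" (falling back to "max_tokens")
# from litellm.model_cost.
print(litellm.get_max_tokens("gpt-3.5-turbo"))  # e.g. 4096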