fix(utils.py): check if model info is for model with correct provider

Fixes an issue where incorrect pricing was used for a custom LLM provider: get_model_info() matched litellm.model_cost entries by model name alone, without checking which provider the entry belongs to. (A usage sketch follows the commit metadata below.)
Krrish Dholakia 2024-06-13 15:54:24 -07:00
parent d210eccb79
commit 345094a49d
8 changed files with 55 additions and 18 deletions
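
A minimal usage sketch of the new parameter, assuming litellm model-cost entries that exist at the time of this commit (the model and provider names are illustrative, not taken from this commit):

    import litellm

    # Pricing is now validated against the requested provider:
    info = litellm.get_model_info("groq/llama3-8b-8192", custom_llm_provider="groq")
    print(info["input_cost_per_token"], info["output_cost_per_token"])

    # An entry that belongs to a different provider is rejected instead of
    # silently supplying that provider's pricing:
    try:
        litellm.get_model_info("gpt-3.5-turbo", custom_llm_provider="groq")
    except Exception:
        pass  # provider mismatch now raises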

@@ -6953,13 +6953,14 @@ def get_max_tokens(model: str):
         )
 
 
-def get_model_info(model: str) -> ModelInfo:
+def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
     """
     Get a dict for the maximum tokens (context window),
     input_cost_per_token, output_cost_per_token for a given model.
 
     Parameters:
-    model (str): The name of the model.
+    - model (str): The name of the model.
+    - custom_llm_provider (str | null): the provider used for the model. If provided, used to check if the litellm model info is for that provider.
 
     Returns:
         dict: A dictionary containing the following information:
@@ -7013,12 +7014,14 @@ def get_model_info(model: str) -> ModelInfo:
         if model in azure_llms:
             model = azure_llms[model]
         ##########################
-        # Get custom_llm_provider
-        split_model, custom_llm_provider = model, ""
-        try:
-            split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
-        except:
-            pass
+        if custom_llm_provider is None:
+            # Get custom_llm_provider
+            try:
+                split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+            except:
+                pass
+        else:
+            split_model = model
         #########################
         supported_openai_params = litellm.get_supported_openai_params(
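
For context on the hunk above: get_llm_provider() splits a "provider/model" string into the bare model name and its provider. A hedged illustration of the 4-tuple unpacking used in the diff (the example model name is an assumption, not from this commit):

    split_model, custom_llm_provider, _, _ = litellm.get_llm_provider(
        model="groq/llama3-8b-8192"
    )
    # split_model == "llama3-8b-8192", custom_llm_provider == "groq";
    # the two ignored values are the dynamic api key and api base.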
@@ -7043,10 +7046,20 @@ def get_model_info(model: str) -> ModelInfo:
         if model in litellm.model_cost:
             _model_info = litellm.model_cost[model]
             _model_info["supported_openai_params"] = supported_openai_params
+            if (
+                "litellm_provider" in _model_info
+                and _model_info["litellm_provider"] != custom_llm_provider
+            ):
+                raise Exception
             return _model_info
         if split_model in litellm.model_cost:
             _model_info = litellm.model_cost[split_model]
             _model_info["supported_openai_params"] = supported_openai_params
+            if (
+                "litellm_provider" in _model_info
+                and _model_info["litellm_provider"] != custom_llm_provider
+            ):
+                raise Exception
             return _model_info
         else:
             raise ValueError(
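
Taken together, the hunks try the full model name first, then the provider-stripped name, and refuse any entry mapped to a different provider. A standalone sketch of that lookup order (the helper name and error message are hypothetical, not part of this commit):

    from typing import Optional
    import litellm

    def provider_checked_lookup(
        model: str, split_model: str, custom_llm_provider: Optional[str]
    ) -> Optional[dict]:
        # Mirror of the diff's logic: full name first, then split name;
        # a name match with the wrong provider raises rather than returning
        # that provider's pricing. (The real function raises ValueError
        # when neither name matches.)
        for key in (model, split_model):
            _model_info = litellm.model_cost.get(key)
            if _model_info is None:
                continue
            if (
                "litellm_provider" in _model_info
                and _model_info["litellm_provider"] != custom_llm_provider
            ):
                raise Exception(f"model info for {key} is for another provider")
            return _model_info
        return None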