fix(utils.py): check if model info is for model with correct provider

Fixes an issue where incorrect pricing was used for a custom LLM provider
Krrish Dholakia 2024-06-13 15:54:24 -07:00
parent d210eccb79
commit 345094a49d
8 changed files with 55 additions and 18 deletions
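
A minimal sketch of the intended behavior after this fix (the model and provider names mirror the new test below and are illustrative only): get_model_info now takes a custom_llm_provider argument and rejects a model cost map entry whose litellm_provider does not match it, so a vLLM deployment that reuses a well-known model name no longer inherits that model's hosted pricing, and callers can fall back to the costs configured in litellm_params.

    import litellm

    # "command-r-plus" exists in litellm's model cost map, but not under the
    # "openai" provider, so an openai-compatible vLLM deployment reusing the
    # name should no longer pick up that entry's pricing.
    try:
        info = litellm.get_model_info("command-r-plus", custom_llm_provider="openai")
    except Exception:
        # Provider mismatch (or unmapped model): callers fall back to the costs
        # set in litellm_params, e.g. input_cost_per_token: 0 / output_cost_per_token: 0.
        info = None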

3 file diffs suppressed because one or more lines are too long

@@ -55,7 +55,16 @@ model_list:
       model: textembedding-gecko-multilingual@001
       vertex_project: my-project-9d5c
       vertex_location: us-central1
+  - model_name: lbl/command-r-plus
+    litellm_params:
+      model: openai/lbl/command-r-plus
+      api_key: "os.environ/VLLM_API_KEY"
+      api_base: http://vllm-command:8000/v1
+      rpm: 1000
+      input_cost_per_token: 0
+      output_cost_per_token: 0
+    model_info:
+      max_input_tokens: 80920
 assistant_settings:
   custom_llm_provider: openai
   litellm_params:

@@ -11402,7 +11402,7 @@ async def model_info_v2(
     for _model in all_models:
         # provided model_info in config.yaml
         model_info = _model.get("model_info", {})
-        if debug == True:
+        if debug is True:
             _openai_client = "None"
             if llm_router is not None:
                 _openai_client = (
@@ -11427,7 +11427,7 @@ async def model_info_v2(
         litellm_model = litellm_params.get("model", None)
         try:
             litellm_model_info = litellm.get_model_info(model=litellm_model)
-        except:
+        except Exception:
             litellm_model_info = {}
         # 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
         if litellm_model_info == {}:
@@ -11438,8 +11438,10 @@ async def model_info_v2(
             if len(split_model) > 0:
                 litellm_model = split_model[-1]
             try:
-                litellm_model_info = litellm.get_model_info(model=litellm_model)
-            except:
+                litellm_model_info = litellm.get_model_info(
+                    model=litellm_model, custom_llm_provider=split_model[0]
+                )
+            except Exception:
                 litellm_model_info = {}
             for k, v in litellm_model_info.items():
                 if k not in model_info:
@@ -11950,7 +11952,9 @@ async def model_info_v1(
             if len(split_model) > 0:
                 litellm_model = split_model[-1]
             try:
-                litellm_model_info = litellm.get_model_info(model=litellm_model)
+                litellm_model_info = litellm.get_model_info(
+                    model=litellm_model, custom_llm_provider=split_model[0]
+                )
             except:
                 litellm_model_info = {}
             for k, v in litellm_model_info.items():
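
For reference, a plain sketch of the lookup fallback that model_info_v2 and model_info_v1 now share (the model name is the illustrative vLLM entry from the config above): if the full litellm model name is not in the cost map, the name is split on "/" and the lookup is retried with the last segment as the model and the first segment as the provider, which is what lets get_model_info reject entries registered under a different provider.

    import litellm

    litellm_model = "openai/lbl/command-r-plus"  # from litellm_params in the config
    try:
        litellm_model_info = litellm.get_model_info(model=litellm_model)
    except Exception:
        litellm_model_info = {}
    if litellm_model_info == {}:
        split_model = litellm_model.split("/")  # ["openai", "lbl", "command-r-plus"]
        try:
            # Retry without the prefix, but pin the provider so a same-named model
            # from another provider is not silently matched.
            litellm_model_info = litellm.get_model_info(
                model=split_model[-1], custom_llm_provider=split_model[0]
            )
        except Exception:
            litellm_model_info = {}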

@@ -7,6 +7,7 @@ sys.path.insert(
 ) # Adds the parent directory to the system path
 import litellm
 from litellm import get_model_info
+import pytest
 
 
 def test_get_model_info_simple_model_name():
@@ -23,3 +24,16 @@ def test_get_model_info_custom_llm_with_model_name():
     """
     model = "anthropic/claude-3-opus-20240229"
     litellm.get_model_info(model)
+
+
+def test_get_model_info_custom_llm_with_same_name_vllm():
+    """
+    Tests if {custom_llm_provider}/{model_name} name given, and model exists in model info, the object is returned
+    """
+    model = "command-r-plus"
+    provider = "openai"  # vllm is openai-compatible
+    try:
+        litellm.get_model_info(model, custom_llm_provider=provider)
+        pytest.fail("Expected get model info to fail for an unmapped model/provider")
+    except Exception:
+        pass

@@ -6953,13 +6953,14 @@ def get_max_tokens(model: str):
     )
 
 
-def get_model_info(model: str) -> ModelInfo:
+def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
     """
     Get a dict for the maximum tokens (context window),
     input_cost_per_token, output_cost_per_token for a given model.
 
     Parameters:
-    model (str): The name of the model.
+    - model (str): The name of the model.
+    - custom_llm_provider (str | null): the provider used for the model. If provided, used to check if the litellm model info is for that provider.
 
     Returns:
         dict: A dictionary containing the following information:
@@ -7013,12 +7014,14 @@ def get_model_info(model: str) -> ModelInfo:
     if model in azure_llms:
         model = azure_llms[model]
     ##########################
+    if custom_llm_provider is None:
+        # Get custom_llm_provider
+        split_model, custom_llm_provider = model, ""
+        try:
+            split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+        except:
+            pass
+    else:
+        split_model = model
     #########################
     supported_openai_params = litellm.get_supported_openai_params(
@@ -7043,10 +7046,20 @@ def get_model_info(model: str) -> ModelInfo:
         if model in litellm.model_cost:
             _model_info = litellm.model_cost[model]
             _model_info["supported_openai_params"] = supported_openai_params
+            if (
+                "litellm_provider" in _model_info
+                and _model_info["litellm_provider"] != custom_llm_provider
+            ):
+                raise Exception
             return _model_info
         if split_model in litellm.model_cost:
             _model_info = litellm.model_cost[split_model]
             _model_info["supported_openai_params"] = supported_openai_params
+            if (
+                "litellm_provider" in _model_info
+                and _model_info["litellm_provider"] != custom_llm_provider
+            ):
+                raise Exception
             return _model_info
         else:
             raise ValueError(
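
A short usage sketch of the updated signature (model name and printed keys are illustrative, not taken from this diff): when the passed provider matches the litellm_provider recorded in the cost map entry, the lookup behaves as before and the pricing fields can be read from the returned info; on a mismatch it now raises, as in the vLLM scenario above.

    import litellm

    # Provider matches the entry's litellm_provider, so the lookup succeeds.
    info = litellm.get_model_info("gpt-3.5-turbo", custom_llm_provider="openai")
    print(info["litellm_provider"], info["input_cost_per_token"], info["output_cost_per_token"])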

@@ -1531,7 +1531,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
                 <pre className="text-xs">
                   {model.input_cost
                     ? model.input_cost
-                    : model.litellm_params.input_cost_per_token
+                    : model.litellm_params.input_cost_per_token != null && model.litellm_params.input_cost_per_token != undefined
                     ? (
                         Number(
                           model.litellm_params