fix(utils.py): check if model info is for model with correct provider

Fixes an issue where incorrect pricing was used for a custom LLM provider
Author: Krrish Dholakia
Date: 2024-06-13 15:54:24 -07:00
commit 345094a49d
parent d210eccb79
8 changed files with 55 additions and 18 deletions
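For orientation, here is a minimal sketch (not part of the commit) of the behavior this change introduces. It assumes the model cost map lists command-r-plus under the cohere provider, as the new test below does: supplying a mismatched custom_llm_provider now makes the lookup raise instead of silently returning another provider's pricing.

import litellm

# get_model_info() now accepts an optional custom_llm_provider.
# "command-r-plus" exists in the model cost map, but under the cohere provider,
# so asking for it as an "openai" (i.e. OpenAI-compatible / vLLM) model raises
# rather than returning cohere pricing for a self-hosted deployment.
try:
    info = litellm.get_model_info("command-r-plus", custom_llm_provider="openai")
    print(info)
except Exception:
    print("no model info for this model/provider pair; fall back to config-provided costs")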

3 file diffs suppressed because one or more lines are too long

@@ -55,7 +55,16 @@ model_list:
       model: textembedding-gecko-multilingual@001
       vertex_project: my-project-9d5c
       vertex_location: us-central1
+  - model_name: lbl/command-r-plus
+    litellm_params:
+      model: openai/lbl/command-r-plus
+      api_key: "os.environ/VLLM_API_KEY"
+      api_base: http://vllm-command:8000/v1
+      rpm: 1000
+      input_cost_per_token: 0
+      output_cost_per_token: 0
+    model_info:
+      max_input_tokens: 80920
 
 assistant_settings:
   custom_llm_provider: openai
   litellm_params:
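As a hedged illustration (not from this commit), this is how a client might call the lbl/command-r-plus alias added above through the LiteLLM proxy; the proxy speaks the OpenAI API, and the base_url and api_key below are placeholder assumptions, not values from the config.

from openai import OpenAI

# Placeholder proxy address and key; substitute your own deployment's values.
client = OpenAI(base_url="http://localhost:4000/v1", api_key="sk-proxy-key")
response = client.chat.completions.create(
    model="lbl/command-r-plus",  # routed by the proxy to openai/lbl/command-r-plus on the vLLM server
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)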

@@ -11402,7 +11402,7 @@ async def model_info_v2(
     for _model in all_models:
         # provided model_info in config.yaml
         model_info = _model.get("model_info", {})
-        if debug == True:
+        if debug is True:
             _openai_client = "None"
             if llm_router is not None:
                 _openai_client = (
@@ -11427,7 +11427,7 @@ async def model_info_v2(
         litellm_model = litellm_params.get("model", None)
         try:
             litellm_model_info = litellm.get_model_info(model=litellm_model)
-        except:
+        except Exception:
             litellm_model_info = {}
         # 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
         if litellm_model_info == {}:
@@ -11438,8 +11438,10 @@ async def model_info_v2(
             if len(split_model) > 0:
                 litellm_model = split_model[-1]
             try:
-                litellm_model_info = litellm.get_model_info(model=litellm_model)
-            except:
+                litellm_model_info = litellm.get_model_info(
+                    model=litellm_model, custom_llm_provider=split_model[0]
+                )
+            except Exception:
                 litellm_model_info = {}
         for k, v in litellm_model_info.items():
             if k not in model_info:
@@ -11950,7 +11952,9 @@ async def model_info_v1(
         if len(split_model) > 0:
             litellm_model = split_model[-1]
         try:
-            litellm_model_info = litellm.get_model_info(model=litellm_model)
+            litellm_model_info = litellm.get_model_info(
+                model=litellm_model, custom_llm_provider=split_model[0]
+            )
         except:
             litellm_model_info = {}
         for k, v in litellm_model_info.items():
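Taken on its own, the fallback lookup that model_info_v2 and model_info_v1 now perform looks roughly like the sketch below (the model string is an assumed example matching the config added above). Because the endpoints wrap the call in try/except and fall back to an empty dict, the config-supplied costs win whenever the guarded lookup raises.

import litellm

# A provider-prefixed config model such as openai/lbl/command-r-plus is not in
# the model cost map as-is. The endpoints retry with the last path segment as
# the model and the first segment as the provider hint, so the lookup can no
# longer pick up pricing from a different provider that shares the model name.
litellm_model = "openai/lbl/command-r-plus"
split_model = litellm_model.split("/")
try:
    litellm_model_info = litellm.get_model_info(
        model=split_model[-1], custom_llm_provider=split_model[0]
    )
except Exception:
    litellm_model_info = {}  # unmapped: keep only the costs provided in config.yaml
print(litellm_model_info)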

@@ -7,6 +7,7 @@ sys.path.insert(
 )  # Adds the parent directory to the system path
 import litellm
 from litellm import get_model_info
+import pytest
 
 
 def test_get_model_info_simple_model_name():
@@ -23,3 +24,16 @@ def test_get_model_info_custom_llm_with_model_name():
     """
     model = "anthropic/claude-3-opus-20240229"
     litellm.get_model_info(model)
+
+
+def test_get_model_info_custom_llm_with_same_name_vllm():
+    """
+    Tests if {custom_llm_provider}/{model_name} name given, and model exists in model info, the object is returned
+    """
+    model = "command-r-plus"
+    provider = "openai"  # vllm is openai-compatible
+    try:
+        litellm.get_model_info(model, custom_llm_provider=provider)
+        pytest.fail("Expected get model info to fail for an unmapped model/provider")
+    except Exception:
+        pass
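The same assertion could also be written with pytest.raises; a sketch equivalent to the test added above, not part of the commit:

import pytest

import litellm


def test_get_model_info_mismatched_provider_raises():
    # Same scenario as the new test: command-r-plus is mapped to the cohere
    # provider in the cost map, so requesting it as an openai-compatible (vLLM)
    # model should raise instead of returning cohere pricing.
    with pytest.raises(Exception):
        litellm.get_model_info("command-r-plus", custom_llm_provider="openai")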

@@ -6953,13 +6953,14 @@ def get_max_tokens(model: str):
     )
 
 
-def get_model_info(model: str) -> ModelInfo:
+def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
     """
     Get a dict for the maximum tokens (context window),
     input_cost_per_token, output_cost_per_token for a given model.
 
     Parameters:
-    model (str): The name of the model.
+    - model (str): The name of the model.
+    - custom_llm_provider (str | null): the provider used for the model. If provided, used to check if the litellm model info is for that provider.
 
     Returns:
         dict: A dictionary containing the following information:
@@ -7013,12 +7014,14 @@ def get_model_info(model: str) -> ModelInfo:
         if model in azure_llms:
             model = azure_llms[model]
         ##########################
-        # Get custom_llm_provider
-        split_model, custom_llm_provider = model, ""
-        try:
-            split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
-        except:
-            pass
+        if custom_llm_provider is None:
+            # Get custom_llm_provider
+            try:
+                split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+            except:
+                pass
+        else:
+            split_model = model
         #########################
 
         supported_openai_params = litellm.get_supported_openai_params(
@@ -7043,10 +7046,20 @@ def get_model_info(model: str) -> ModelInfo:
         if model in litellm.model_cost:
             _model_info = litellm.model_cost[model]
             _model_info["supported_openai_params"] = supported_openai_params
+            if (
+                "litellm_provider" in _model_info
+                and _model_info["litellm_provider"] != custom_llm_provider
+            ):
+                raise Exception
             return _model_info
         if split_model in litellm.model_cost:
             _model_info = litellm.model_cost[split_model]
             _model_info["supported_openai_params"] = supported_openai_params
+            if (
+                "litellm_provider" in _model_info
+                and _model_info["litellm_provider"] != custom_llm_provider
+            ):
+                raise Exception
             return _model_info
         else:
             raise ValueError(
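For contrast, a sketch of the matching path, assuming the standard cost-map entry for gpt-3.5-turbo (which lists openai as its litellm_provider): when the supplied provider agrees with the entry, the new guard passes and the pricing and context-window information comes back unchanged.

import litellm

# Provider matches the cost-map entry, so no exception is raised.
info = litellm.get_model_info("gpt-3.5-turbo", custom_llm_provider="openai")
print(info["max_tokens"], info["input_cost_per_token"], info["output_cost_per_token"])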

@@ -1531,7 +1531,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
                     <pre className="text-xs">
                       {model.input_cost
                         ? model.input_cost
-                        : model.litellm_params.input_cost_per_token
+                        : model.litellm_params.input_cost_per_token != null && model.litellm_params.input_cost_per_token != undefined
                           ? (
                               Number(
                                 model.litellm_params