fix(utils.py): fix model registration to model cost map

Fixes https://github.com/BerriAI/litellm/issues/4972
Krrish Dholakia 2024-07-30 18:15:00 -07:00
parent 142f4fefd0
commit 46634af06f
5 changed files with 73 additions and 14 deletions


@@ -2148,6 +2148,13 @@ def supports_parallel_function_calling(model: str):
 
 
 ####### HELPER FUNCTIONS ################
+def _update_dictionary(existing_dict: dict, new_dict: dict) -> dict:
+    for k, v in new_dict.items():
+        existing_dict[k] = v
+
+    return existing_dict
+
+
 def register_model(model_cost: Union[str, dict]):
     """
     Register new / Override existing models (and their pricing) to specific providers.
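For context, the new `_update_dictionary` helper is a shallow merge: every key in `new_dict` overwrites the matching key in `existing_dict`, while keys absent from `new_dict` are left untouched. A minimal illustration (the values are made up):

    existing = {"max_tokens": 4096, "input_cost_per_token": 1e-06}
    override = {"input_cost_per_token": 2e-06}

    merged = _update_dictionary(existing, override)
    # merged == {"max_tokens": 4096, "input_cost_per_token": 2e-06}
    # i.e. the same result as calling existing.update(override) in place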
@@ -2170,8 +2177,17 @@ def register_model(model_cost: Union[str, dict]):
         loaded_model_cost = litellm.get_model_cost_map(url=model_cost)
 
     for key, value in loaded_model_cost.items():
+        ## get model info ##
+        try:
+            existing_model = get_model_info(model=key)
+            model_cost_key = existing_model["key"]
+        except Exception:
+            existing_model = {}
+            model_cost_key = key
         ## override / add new keys to the existing model cost dictionary
-        litellm.model_cost.setdefault(key, {}).update(value)
+        litellm.model_cost.setdefault(model_cost_key, {}).update(
+            _update_dictionary(existing_model, value)
+        )
         verbose_logger.debug(f"{key} added to model cost map")
         # add new model names to provider lists
         if value.get("litellm_provider") == "openai":
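With this hunk, an override no longer lands blindly under the caller-supplied key: `get_model_info` first resolves the canonical `litellm.model_cost` key, and the override is merged into that existing entry instead of creating a duplicate. A usage sketch (the model name and rates are illustrative placeholders, not real pricing):

    import litellm

    litellm.register_model(
        {
            "gpt-4o": {
                "input_cost_per_token": 5e-06,     # placeholder rate
                "output_cost_per_token": 1.5e-05,  # placeholder rate
                "litellm_provider": "openai",
                "mode": "chat",
            }
        }
    )
    # The new values are merged into whichever litellm.model_cost entry
    # get_model_info(model="gpt-4o") resolves to, so fields that are not
    # overridden (e.g. max_tokens) are preserved.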
@@ -4858,6 +4874,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
     Returns:
         dict: A dictionary containing the following information:
+            key: Required[str]  # the key in litellm.model_cost which is returned
             max_tokens: Required[Optional[int]]
             max_input_tokens: Required[Optional[int]]
             max_output_tokens: Required[Optional[int]]
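Because `key` is now part of `ModelInfo`, callers can see which cost-map entry a lookup actually resolved to. A small sketch (the model name is just an example):

    import litellm

    info = litellm.get_model_info(model="groq/llama3-8b-8192")
    # Reports which litellm.model_cost key matched: the provider-prefixed
    # name, the raw model name, or the name with the prefix split off.
    print(info["key"])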
@@ -4959,6 +4976,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
         if custom_llm_provider == "huggingface":
             max_tokens = _get_max_position_embeddings(model_name=model)
             return ModelInfo(
+                key=model,
                 max_tokens=max_tokens,  # type: ignore
                 max_input_tokens=None,
                 max_output_tokens=None,
@@ -4979,6 +4997,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
             3. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
             """
             if combined_model_name in litellm.model_cost:
+                key = combined_model_name
                 _model_info = litellm.model_cost[combined_model_name]
                 _model_info["supported_openai_params"] = supported_openai_params
                 if (
@@ -4992,6 +5011,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
                 else:
                     raise Exception
             elif model in litellm.model_cost:
+                key = model
                 _model_info = litellm.model_cost[model]
                 _model_info["supported_openai_params"] = supported_openai_params
                 if (
@@ -5005,6 +5025,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
                 else:
                     raise Exception
             elif split_model in litellm.model_cost:
+                key = split_model
                 _model_info = litellm.model_cost[split_model]
                 _model_info["supported_openai_params"] = supported_openai_params
                 if (
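Taken together, the three branches record whichever name hit `litellm.model_cost` first. Condensed (a sketch of the surrounding control flow, not the full function):

    # combined_model_name is e.g. "groq/llama3-8b-8192", model is the name as
    # passed in, and split_model is the part after the "provider/" prefix.
    if combined_model_name in litellm.model_cost:
        key = combined_model_name
    elif model in litellm.model_cost:
        key = model
    elif split_model in litellm.model_cost:
        key = split_model
    # `key` is then returned in the ModelInfo constructed below.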
@@ -5027,6 +5048,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
                     _model_info["supports_response_schema"] = True
             return ModelInfo(
+                key=key,
                 max_tokens=_model_info.get("max_tokens", None),
                 max_input_tokens=_model_info.get("max_input_tokens", None),
                 max_output_tokens=_model_info.get("max_output_tokens", None),
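As an end-to-end check of the fix described in the commit message (the model name and rate are illustrative):

    import litellm

    litellm.register_model(
        {"groq/llama3-8b-8192": {"input_cost_per_token": 1e-07}}
    )
    info = litellm.get_model_info(model="groq/llama3-8b-8192")
    assert info["input_cost_per_token"] == 1e-07
    # The rest of the original entry (max_tokens, provider, mode, ...)
    # survives the merge rather than being replaced wholesale.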