fix(utils.py): fix model registration to model cost map
Fixes https://github.com/BerriAI/litellm/issues/4972
parent 142f4fefd0
commit 46634af06f
5 changed files with 73 additions and 14 deletions
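For context, register_model is litellm's public hook for adding new entries to, or overriding existing entries in, the in-memory model cost map (litellm.model_cost). A minimal usage sketch; the model name and prices are made up for illustration:

    import litellm

    litellm.register_model(
        {
            "my-org/my-custom-model": {  # hypothetical model name
                "max_tokens": 8192,
                "input_cost_per_token": 1.5e-06,  # illustrative pricing
                "output_cost_per_token": 2e-06,
                "litellm_provider": "openai",
                "mode": "chat",
            }
        }
    )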
@@ -2148,6 +2148,13 @@ def supports_parallel_function_calling(model: str):
 
 
 ####### HELPER FUNCTIONS ################
+def _update_dictionary(existing_dict: dict, new_dict: dict) -> dict:
+    for k, v in new_dict.items():
+        existing_dict[k] = v
+
+    return existing_dict
+
+
 def register_model(model_cost: Union[str, dict]):
     """
     Register new / Override existing models (and their pricing) to specific providers.
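_update_dictionary is a shallow merge: each key in new_dict overwrites or extends existing_dict in place, and the same (mutated) dict is returned. A quick sketch of that behavior with made-up values:

    existing = {"max_tokens": 4096, "input_cost_per_token": 1e-06}
    override = {"input_cost_per_token": 2e-06, "mode": "chat"}

    merged = _update_dictionary(existing, override)
    assert merged is existing  # mutated in place, not copied
    assert merged == {
        "max_tokens": 4096,
        "input_cost_per_token": 2e-06,
        "mode": "chat",
    }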
@@ -2170,8 +2177,17 @@ def register_model(model_cost: Union[str, dict]):
     loaded_model_cost = litellm.get_model_cost_map(url=model_cost)
 
     for key, value in loaded_model_cost.items():
+        ## get model info ##
+        try:
+            existing_model = get_model_info(model=key)
+            model_cost_key = existing_model["key"]
+        except Exception:
+            existing_model = {}
+            model_cost_key = key
         ## override / add new keys to the existing model cost dictionary
-        litellm.model_cost.setdefault(key, {}).update(value)
+        litellm.model_cost.setdefault(model_cost_key, {}).update(
+            _update_dictionary(existing_model, value)
+        )
         verbose_logger.debug(f"{key} added to model cost map")
         # add new model names to provider lists
         if value.get("litellm_provider") == "openai":
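This hunk is the heart of the fix: instead of always writing under the caller-supplied key, the loop first asks get_model_info for the canonical cost-map key (falling back to the raw key for models that are not yet registered) and merges the override into that entry. A sketch of the resulting behavior, with an illustrative price:

    import litellm

    # Override one field of an existing model; the canonical entry is
    # updated in place rather than shadowed by a new entry.
    litellm.register_model({"gpt-4o": {"input_cost_per_token": 1.23e-05}})
    assert litellm.model_cost["gpt-4o"]["input_cost_per_token"] == 1.23e-05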
@@ -4858,6 +4874,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
 
     Returns:
         dict: A dictionary containing the following information:
+            key: Required[str] # the key in litellm.model_cost which is returned
             max_tokens: Required[Optional[int]]
             max_input_tokens: Required[Optional[int]]
             max_output_tokens: Required[Optional[int]]
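The new key field exposes which litellm.model_cost entry actually matched, which is what register_model relies on above. A usage sketch:

    from litellm import get_model_info

    info = get_model_info(model="groq/llama3-8b-8192")
    print(info["key"])         # the litellm.model_cost entry that matched
    print(info["max_tokens"])  # may be None depending on the entry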
@@ -4959,6 +4976,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
         if custom_llm_provider == "huggingface":
             max_tokens = _get_max_position_embeddings(model_name=model)
             return ModelInfo(
+                key=model,
                 max_tokens=max_tokens,  # type: ignore
                 max_input_tokens=None,
                 max_output_tokens=None,
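Hugging Face models are generally absent from the static cost map, so max_tokens is derived from the model's own config instead. A hypothetical reconstruction of what a helper like _get_max_position_embeddings plausibly does (an assumption, not the actual implementation):

    import requests

    def _get_max_position_embeddings_sketch(model_name: str):
        # Assumed behavior: read the context length from the model's
        # config.json on the Hugging Face Hub.
        url = f"https://huggingface.co/{model_name}/raw/main/config.json"
        resp = requests.get(url, timeout=10)
        if resp.status_code != 200:
            return None
        return resp.json().get("max_position_embeddings")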
@@ -4979,6 +4997,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
             3. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
             """
             if combined_model_name in litellm.model_cost:
+                key = combined_model_name
                 _model_info = litellm.model_cost[combined_model_name]
                 _model_info["supported_openai_params"] = supported_openai_params
                 if (
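The docstring above describes a three-step fallback; conceptually the lookup reduces to a chain like the following simplified sketch (the real code also handles provider-specific cases and mismatched litellm_provider values):

    from typing import Optional

    import litellm

    def _resolve_cost_key_sketch(model: str, custom_llm_provider: Optional[str]) -> str:
        # Simplified illustration of the lookup order in get_model_info.
        combined = f"{custom_llm_provider}/{model}" if custom_llm_provider else model
        split = model.split("/")[-1]
        for candidate in (combined, model, split):
            if candidate in litellm.model_cost:
                return candidate
        raise Exception(f"model {model} not found in model cost map")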
@@ -4992,6 +5011,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
                     else:
                         raise Exception
             elif model in litellm.model_cost:
+                key = model
                 _model_info = litellm.model_cost[model]
                 _model_info["supported_openai_params"] = supported_openai_params
                 if (
@@ -5005,6 +5025,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
                     else:
                         raise Exception
             elif split_model in litellm.model_cost:
+                key = split_model
                 _model_info = litellm.model_cost[split_model]
                 _model_info["supported_openai_params"] = supported_openai_params
                 if (
@@ -5027,6 +5048,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
                     _model_info["supports_response_schema"] = True
 
             return ModelInfo(
+                key=key,
                 max_tokens=_model_info.get("max_tokens", None),
                 max_input_tokens=_model_info.get("max_input_tokens", None),
                 max_output_tokens=_model_info.get("max_output_tokens", None),
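All three lookup branches funnel into this single ModelInfo construction, and every numeric field is read with dict.get(..., None), so callers should expect None for limits a cost-map entry does not define:

    info = get_model_info(model="groq/llama3-8b-8192")
    # Fall back to max_tokens when the entry has no explicit input limit.
    max_input = info["max_input_tokens"] or info["max_tokens"]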