diff --git a/litellm/utils.py b/litellm/utils.py index 98a9c34b47..9ccc0dfa20 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2291,6 +2291,11 @@ def register_model(model_cost: Union[str, dict]): # noqa: PLR0915 loaded_model_cost = litellm.get_model_cost_map(url=model_cost) for key, value in loaded_model_cost.items(): + ## Change, if present, input_cost_per_token and output_cost_per_token to float + if "input_cost_per_token" in value: + value["input_cost_per_token"] = float(value["input_cost_per_token"]) + if "output_cost_per_token" in value: + value["output_cost_per_token"] = float(value["output_cost_per_token"]) ## get model info ## try: existing_model: dict = cast(dict, get_model_info(model=key))