diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index 4ca8731635..13f0070c9c 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -1111,28 +1111,28 @@ def default_image_cost_calculator( f"Looking up cost for models: {model_name_with_quality}, {base_model_name}" ) - # Try model with quality first, fall back to base model name - if model_name_with_quality in litellm.model_cost: - cost_info = litellm.model_cost[model_name_with_quality] - elif base_model_name in litellm.model_cost: - cost_info = litellm.model_cost[base_model_name] - elif model_name_with_v2_quality in litellm.model_cost: - cost_info = litellm.model_cost[model_name_with_v2_quality] - else: - # Try without provider prefix - model_without_provider = f"{size_str}/{model.split('/')[-1]}" - model_with_quality_without_provider = ( - f"{quality}/{model_without_provider}" if quality else model_without_provider - ) + model_without_provider = f"{size_str}/{model.split('/')[-1]}" + model_with_quality_without_provider = ( + f"{quality}/{model_without_provider}" if quality else model_without_provider + ) - if model_with_quality_without_provider in litellm.model_cost: - cost_info = litellm.model_cost[model_with_quality_without_provider] - elif model_without_provider in litellm.model_cost: - cost_info = litellm.model_cost[model_without_provider] - else: - raise Exception( - f"Model not found in cost map. Tried {model_name_with_quality}, {base_model_name}, {model_with_quality_without_provider}, and {model_without_provider}" - ) + # Try model with quality first, fall back to base model name + cost_info: Optional[dict] = None + models_to_check = [ + model_name_with_quality, + base_model_name, + model_name_with_v2_quality, + model_with_quality_without_provider, + model_without_provider, + ] + for model in models_to_check: + if model in litellm.model_cost: + cost_info = litellm.model_cost[model] + break + if cost_info is None: + raise Exception( + f"Model not found in cost map. Tried checking {models_to_check}" + ) return cost_info["input_cost_per_pixel"] * height * width * n