From 2e58e47b43f38df3b892784a6b99bac070df6000 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Wed, 23 Apr 2025 15:16:40 -0700
Subject: [PATCH] [Bug Fix] Add Cost Tracking for gpt-image-1 when quality is
 unspecified (#10247)

* TestOpenAIGPTImage1

* fixes for cost calc

* fix ImageGenerationRequestQuality.MEDIUM

---
 litellm/cost_calculator.py                   | 54 +++++++++++--------
 litellm/main.py                              | 27 +++++-----
 litellm/types/llms/openai.py                 | 45 +++++++++-------
 .../base_image_generation_test.py            |  6 +--
 .../image_gen_tests/test_image_generation.py |  3 ++
 5 files changed, 78 insertions(+), 57 deletions(-)

diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py
index 7f3d4fcc9f..eafd924bc6 100644
--- a/litellm/cost_calculator.py
+++ b/litellm/cost_calculator.py
@@ -57,6 +57,7 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
 from litellm.responses.utils import ResponseAPILoggingUtils
 from litellm.types.llms.openai import (
     HttpxBinaryResponseContent,
+    ImageGenerationRequestQuality,
     OpenAIRealtimeStreamList,
     OpenAIRealtimeStreamResponseBaseObject,
     OpenAIRealtimeStreamSessionEvents,
@@ -642,9 +643,9 @@ def completion_cost(  # noqa: PLR0915
         or isinstance(completion_response, dict)
     ):  # tts returns a custom class
         if isinstance(completion_response, dict):
-            usage_obj: Optional[
-                Union[dict, Usage]
-            ] = completion_response.get("usage", {})
+            usage_obj: Optional[Union[dict, Usage]] = (
+                completion_response.get("usage", {})
+            )
         else:
             usage_obj = getattr(completion_response, "usage", {})
         if isinstance(usage_obj, BaseModel) and not _is_known_usage_objects(
@@ -913,7 +914,7 @@ def completion_cost(  # noqa: PLR0915
 
 
 def get_response_cost_from_hidden_params(
-    hidden_params: Union[dict, BaseModel]
+    hidden_params: Union[dict, BaseModel],
 ) -> Optional[float]:
     if isinstance(hidden_params, BaseModel):
         _hidden_params_dict = hidden_params.model_dump()
@@ -1101,30 +1102,37 @@ def default_image_cost_calculator(
         f"{quality}/{base_model_name}" if quality else base_model_name
     )
 
+    # gpt-image-1 models use low, medium, and high quality. If the user did not specify a quality, use medium for the gpt-image-1 model family.
+    model_name_with_v2_quality = (
+        f"{ImageGenerationRequestQuality.MEDIUM.value}/{base_model_name}"
+    )
+
     verbose_logger.debug(
         f"Looking up cost for models: {model_name_with_quality}, {base_model_name}"
     )
 
-    # Try model with quality first, fall back to base model name
-    if model_name_with_quality in litellm.model_cost:
-        cost_info = litellm.model_cost[model_name_with_quality]
-    elif base_model_name in litellm.model_cost:
-        cost_info = litellm.model_cost[base_model_name]
-    else:
-        # Try without provider prefix
-        model_without_provider = f"{size_str}/{model.split('/')[-1]}"
-        model_with_quality_without_provider = (
-            f"{quality}/{model_without_provider}" if quality else model_without_provider
-        )
+    model_without_provider = f"{size_str}/{model.split('/')[-1]}"
+    model_with_quality_without_provider = (
+        f"{quality}/{model_without_provider}" if quality else model_without_provider
+    )
 
-        if model_with_quality_without_provider in litellm.model_cost:
-            cost_info = litellm.model_cost[model_with_quality_without_provider]
-        elif model_without_provider in litellm.model_cost:
-            cost_info = litellm.model_cost[model_without_provider]
-        else:
-            raise Exception(
-                f"Model not found in cost map. Tried {model_name_with_quality}, {base_model_name}, {model_with_quality_without_provider}, and {model_without_provider}"
-            )
+    # Try model with quality first, then fall back through the remaining names
+    cost_info: Optional[dict] = None
+    models_to_check = [
+        model_name_with_quality,
+        base_model_name,
+        model_name_with_v2_quality,
+        model_with_quality_without_provider,
+        model_without_provider,
+    ]
+    for model_name in models_to_check:  # avoid shadowing the `model` parameter
+        if model_name in litellm.model_cost:
+            cost_info = litellm.model_cost[model_name]
+            break
+    if cost_info is None:
+        raise Exception(
+            f"Model not found in cost map. Tried checking {models_to_check}"
+        )
 
     return cost_info["input_cost_per_pixel"] * height * width * n
diff --git a/litellm/main.py b/litellm/main.py
index 80486fbe02..de0716fd96 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -182,6 +182,7 @@ from .types.llms.openai import (
     ChatCompletionPredictionContentParam,
     ChatCompletionUserMessage,
     HttpxBinaryResponseContent,
+    ImageGenerationRequestQuality,
 )
 from .types.utils import (
     LITELLM_IMAGE_VARIATION_PROVIDERS,
@@ -2688,9 +2689,9 @@ def completion(  # type: ignore # noqa: PLR0915
                     "aws_region_name" not in optional_params
                     or optional_params["aws_region_name"] is None
                 ):
-                    optional_params[
-                        "aws_region_name"
-                    ] = aws_bedrock_client.meta.region_name
+                    optional_params["aws_region_name"] = (
+                        aws_bedrock_client.meta.region_name
+                    )
 
             bedrock_route = BedrockModelInfo.get_bedrock_route(model)
             if bedrock_route == "converse":
@@ -4412,9 +4413,9 @@ def adapter_completion(
     new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
     response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs)  # type: ignore
-    translated_response: Optional[
-        Union[BaseModel, AdapterCompletionStreamWrapper]
-    ] = None
+    translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
+        None
+    )
 
     if isinstance(response, ModelResponse):
         translated_response = translation_obj.translate_completion_output_params(
             response=response
@@ -4567,7 +4568,7 @@ def image_generation(  # noqa: PLR0915
     prompt: str,
     model: Optional[str] = None,
     n: Optional[int] = None,
-    quality: Optional[str] = None,
+    quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
     response_format: Optional[str] = None,
     size: Optional[str] = None,
     style: Optional[str] = None,
@@ -5834,9 +5835,9 @@ def stream_chunk_builder(  # noqa: PLR0915
         ]
 
         if len(content_chunks) > 0:
-            response["choices"][0]["message"][
-                "content"
-            ] = processor.get_combined_content(content_chunks)
+            response["choices"][0]["message"]["content"] = (
+                processor.get_combined_content(content_chunks)
+            )
 
         reasoning_chunks = [
             chunk
@@ -5847,9 +5848,9 @@ def stream_chunk_builder(  # noqa: PLR0915
         ]
 
         if len(reasoning_chunks) > 0:
-            response["choices"][0]["message"][
-                "reasoning_content"
-            ] = processor.get_combined_reasoning_content(reasoning_chunks)
+            response["choices"][0]["message"]["reasoning_content"] = (
+                processor.get_combined_reasoning_content(reasoning_chunks)
+            )
 
         audio_chunks = [
             chunk
diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py
index dc45ebe5cc..be5f7585bf 100644
--- a/litellm/types/llms/openai.py
+++ b/litellm/types/llms/openai.py
@@ -824,12 +824,12 @@ class OpenAIChatCompletionChunk(ChatCompletionChunk):
 
 class Hyperparameters(BaseModel):
     batch_size: Optional[Union[str, int]] = None  # "Number of examples in each batch."
-    learning_rate_multiplier: Optional[
-        Union[str, float]
-    ] = None  # Scaling factor for the learning rate
-    n_epochs: Optional[
-        Union[str, int]
-    ] = None  # "The number of epochs to train the model for"
+    learning_rate_multiplier: Optional[Union[str, float]] = (
+        None  # Scaling factor for the learning rate
+    )
+    n_epochs: Optional[Union[str, int]] = (
+        None  # "The number of epochs to train the model for"
+    )
 
 
 class FineTuningJobCreate(BaseModel):
@@ -856,18 +856,18 @@ class FineTuningJobCreate(BaseModel):
 
     model: str  # "The name of the model to fine-tune."
     training_file: str  # "The ID of an uploaded file that contains training data."
-    hyperparameters: Optional[
-        Hyperparameters
-    ] = None  # "The hyperparameters used for the fine-tuning job."
-    suffix: Optional[
-        str
-    ] = None  # "A string of up to 18 characters that will be added to your fine-tuned model name."
-    validation_file: Optional[
-        str
-    ] = None  # "The ID of an uploaded file that contains validation data."
-    integrations: Optional[
-        List[str]
-    ] = None  # "A list of integrations to enable for your fine-tuning job."
+    hyperparameters: Optional[Hyperparameters] = (
+        None  # "The hyperparameters used for the fine-tuning job."
+    )
+    suffix: Optional[str] = (
+        None  # "A string of up to 18 characters that will be added to your fine-tuned model name."
+    )
+    validation_file: Optional[str] = (
+        None  # "The ID of an uploaded file that contains validation data."
+    )
+    integrations: Optional[List[str]] = (
+        None  # "A list of integrations to enable for your fine-tuning job."
+    )
     seed: Optional[int] = None  # "The seed controls the reproducibility of the job."
@@ -1259,3 +1259,12 @@ class OpenAIRealtimeStreamResponseBaseObject(TypedDict):
 OpenAIRealtimeStreamList = List[
     Union[OpenAIRealtimeStreamResponseBaseObject, OpenAIRealtimeStreamSessionEvents]
 ]
+
+
+class ImageGenerationRequestQuality(str, Enum):
+    LOW = "low"
+    MEDIUM = "medium"
+    HIGH = "high"
+    AUTO = "auto"
+    STANDARD = "standard"
+    HD = "hd"
diff --git a/tests/image_gen_tests/base_image_generation_test.py b/tests/image_gen_tests/base_image_generation_test.py
index cf01390e07..76fa2f5540 100644
--- a/tests/image_gen_tests/base_image_generation_test.py
+++ b/tests/image_gen_tests/base_image_generation_test.py
@@ -66,8 +66,8 @@ class BaseImageGenTest(ABC):
             logged_standard_logging_payload = custom_logger.standard_logging_payload
             print("logged_standard_logging_payload", logged_standard_logging_payload)
             assert logged_standard_logging_payload is not None
-            # assert logged_standard_logging_payload["response_cost"] is not None
-            # assert logged_standard_logging_payload["response_cost"] > 0
+            assert logged_standard_logging_payload["response_cost"] is not None
+            assert logged_standard_logging_payload["response_cost"] > 0
 
             from openai.types.images_response import ImagesResponse
 
@@ -85,4 +85,4 @@ class BaseImageGenTest(ABC):
             if "Your task failed as a result of our safety system." in str(e):
                 pass
             else:
-                pytest.fail(f"An exception occurred - {str(e)}")
+                pytest.fail(f"An exception occurred - {str(e)}")
\ No newline at end of file
diff --git a/tests/image_gen_tests/test_image_generation.py b/tests/image_gen_tests/test_image_generation.py
index bee9d1f7d5..21f96095a7 100644
--- a/tests/image_gen_tests/test_image_generation.py
+++ b/tests/image_gen_tests/test_image_generation.py
@@ -161,6 +161,9 @@ class TestOpenAIDalle3(BaseImageGenTest):
     def get_base_image_generation_call_args(self) -> dict:
         return {"model": "dall-e-3"}
 
+class TestOpenAIGPTImage1(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
+        return {"model": "gpt-image-1"}
 
 class TestAzureOpenAIDalle3(BaseImageGenTest):
     def get_base_image_generation_call_args(self) -> dict:
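
A minimal usage sketch of the behavior this patch enables (not part of the patch itself): with the medium-quality fallback in place, a gpt-image-1 call that omits `quality` should now resolve a cost-map entry instead of raising. The snippet assumes an OPENAI_API_KEY in the environment and per-pixel gpt-image-1 pricing in the bundled cost map; the key strings in the comments illustrate the lookup order and are not verbatim map contents.

    import litellm

    # No `quality` argument: default_image_cost_calculator now checks, in order,
    #   "1024-x-1024/gpt-image-1"          (base name, since quality is None)
    #   "medium/1024-x-1024/gpt-image-1"   (ImageGenerationRequestQuality.MEDIUM fallback)
    # and the provider-stripped variants, before giving up.
    response = litellm.image_generation(
        prompt="a whale breaching at sunset",
        model="gpt-image-1",
        n=1,
        size="1024x1024",
    )

    # Cost is input_cost_per_pixel * height * width * n from the matched entry.
    cost = litellm.completion_cost(
        completion_response=response,
        model="gpt-image-1",
        call_type="image_generation",
    )
    print(f"response_cost: {cost}")  # should be a positive float, not a "Model not found in cost map" error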