TestOpenAIGPTImage1

This commit is contained in:
Ishaan Jaff 2025-04-23 12:57:58 -07:00
parent 36ee132514
commit 233d1d5303
5 changed files with 59 additions and 38 deletions

View file

@ -57,6 +57,7 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
from litellm.responses.utils import ResponseAPILoggingUtils from litellm.responses.utils import ResponseAPILoggingUtils
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
HttpxBinaryResponseContent, HttpxBinaryResponseContent,
ImageGenerationRequestQuality,
OpenAIRealtimeStreamList, OpenAIRealtimeStreamList,
OpenAIRealtimeStreamResponseBaseObject, OpenAIRealtimeStreamResponseBaseObject,
OpenAIRealtimeStreamSessionEvents, OpenAIRealtimeStreamSessionEvents,
@ -642,9 +643,9 @@ def completion_cost( # noqa: PLR0915
or isinstance(completion_response, dict) or isinstance(completion_response, dict)
): # tts returns a custom class ): # tts returns a custom class
if isinstance(completion_response, dict): if isinstance(completion_response, dict):
usage_obj: Optional[ usage_obj: Optional[Union[dict, Usage]] = (
Union[dict, Usage] completion_response.get("usage", {})
] = completion_response.get("usage", {}) )
else: else:
usage_obj = getattr(completion_response, "usage", {}) usage_obj = getattr(completion_response, "usage", {})
if isinstance(usage_obj, BaseModel) and not _is_known_usage_objects( if isinstance(usage_obj, BaseModel) and not _is_known_usage_objects(
@ -913,7 +914,7 @@ def completion_cost( # noqa: PLR0915
def get_response_cost_from_hidden_params( def get_response_cost_from_hidden_params(
hidden_params: Union[dict, BaseModel] hidden_params: Union[dict, BaseModel],
) -> Optional[float]: ) -> Optional[float]:
if isinstance(hidden_params, BaseModel): if isinstance(hidden_params, BaseModel):
_hidden_params_dict = hidden_params.model_dump() _hidden_params_dict = hidden_params.model_dump()
@ -1101,6 +1102,11 @@ def default_image_cost_calculator(
f"{quality}/{base_model_name}" if quality else base_model_name f"{quality}/{base_model_name}" if quality else base_model_name
) )
# gpt-image-1 models use low, medium, high quality. If user did not specify quality, use medium fot gpt-image-1 model family
model_name_with_v2_quality = (
f"{ImageGenerationRequestQuality.MEDIUM}/{base_model_name}"
)
verbose_logger.debug( verbose_logger.debug(
f"Looking up cost for models: {model_name_with_quality}, {base_model_name}" f"Looking up cost for models: {model_name_with_quality}, {base_model_name}"
) )
@ -1110,6 +1116,8 @@ def default_image_cost_calculator(
cost_info = litellm.model_cost[model_name_with_quality] cost_info = litellm.model_cost[model_name_with_quality]
elif base_model_name in litellm.model_cost: elif base_model_name in litellm.model_cost:
cost_info = litellm.model_cost[base_model_name] cost_info = litellm.model_cost[base_model_name]
elif model_name_with_v2_quality in litellm.model_cost:
cost_info = litellm.model_cost[model_name_with_v2_quality]
else: else:
# Try without provider prefix # Try without provider prefix
model_without_provider = f"{size_str}/{model.split('/')[-1]}" model_without_provider = f"{size_str}/{model.split('/')[-1]}"

View file

@ -182,6 +182,7 @@ from .types.llms.openai import (
ChatCompletionPredictionContentParam, ChatCompletionPredictionContentParam,
ChatCompletionUserMessage, ChatCompletionUserMessage,
HttpxBinaryResponseContent, HttpxBinaryResponseContent,
ImageGenerationRequestQuality,
) )
from .types.utils import ( from .types.utils import (
LITELLM_IMAGE_VARIATION_PROVIDERS, LITELLM_IMAGE_VARIATION_PROVIDERS,
@ -2688,9 +2689,9 @@ def completion( # type: ignore # noqa: PLR0915
"aws_region_name" not in optional_params "aws_region_name" not in optional_params
or optional_params["aws_region_name"] is None or optional_params["aws_region_name"] is None
): ):
optional_params[ optional_params["aws_region_name"] = (
"aws_region_name" aws_bedrock_client.meta.region_name
] = aws_bedrock_client.meta.region_name )
bedrock_route = BedrockModelInfo.get_bedrock_route(model) bedrock_route = BedrockModelInfo.get_bedrock_route(model)
if bedrock_route == "converse": if bedrock_route == "converse":
@ -4412,9 +4413,9 @@ def adapter_completion(
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs) new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore
translated_response: Optional[ translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
Union[BaseModel, AdapterCompletionStreamWrapper] None
] = None )
if isinstance(response, ModelResponse): if isinstance(response, ModelResponse):
translated_response = translation_obj.translate_completion_output_params( translated_response = translation_obj.translate_completion_output_params(
response=response response=response
@ -4567,7 +4568,7 @@ def image_generation( # noqa: PLR0915
prompt: str, prompt: str,
model: Optional[str] = None, model: Optional[str] = None,
n: Optional[int] = None, n: Optional[int] = None,
quality: Optional[str] = None, quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
response_format: Optional[str] = None, response_format: Optional[str] = None,
size: Optional[str] = None, size: Optional[str] = None,
style: Optional[str] = None, style: Optional[str] = None,
@ -5834,9 +5835,9 @@ def stream_chunk_builder( # noqa: PLR0915
] ]
if len(content_chunks) > 0: if len(content_chunks) > 0:
response["choices"][0]["message"][ response["choices"][0]["message"]["content"] = (
"content" processor.get_combined_content(content_chunks)
] = processor.get_combined_content(content_chunks) )
reasoning_chunks = [ reasoning_chunks = [
chunk chunk
@ -5847,9 +5848,9 @@ def stream_chunk_builder( # noqa: PLR0915
] ]
if len(reasoning_chunks) > 0: if len(reasoning_chunks) > 0:
response["choices"][0]["message"][ response["choices"][0]["message"]["reasoning_content"] = (
"reasoning_content" processor.get_combined_reasoning_content(reasoning_chunks)
] = processor.get_combined_reasoning_content(reasoning_chunks) )
audio_chunks = [ audio_chunks = [
chunk chunk

View file

@ -824,12 +824,12 @@ class OpenAIChatCompletionChunk(ChatCompletionChunk):
class Hyperparameters(BaseModel): class Hyperparameters(BaseModel):
batch_size: Optional[Union[str, int]] = None # "Number of examples in each batch." batch_size: Optional[Union[str, int]] = None # "Number of examples in each batch."
learning_rate_multiplier: Optional[ learning_rate_multiplier: Optional[Union[str, float]] = (
Union[str, float] None # Scaling factor for the learning rate
] = None # Scaling factor for the learning rate )
n_epochs: Optional[ n_epochs: Optional[Union[str, int]] = (
Union[str, int] None # "The number of epochs to train the model for"
] = None # "The number of epochs to train the model for" )
class FineTuningJobCreate(BaseModel): class FineTuningJobCreate(BaseModel):
@ -856,18 +856,18 @@ class FineTuningJobCreate(BaseModel):
model: str # "The name of the model to fine-tune." model: str # "The name of the model to fine-tune."
training_file: str # "The ID of an uploaded file that contains training data." training_file: str # "The ID of an uploaded file that contains training data."
hyperparameters: Optional[ hyperparameters: Optional[Hyperparameters] = (
Hyperparameters None # "The hyperparameters used for the fine-tuning job."
] = None # "The hyperparameters used for the fine-tuning job." )
suffix: Optional[ suffix: Optional[str] = (
str None # "A string of up to 18 characters that will be added to your fine-tuned model name."
] = None # "A string of up to 18 characters that will be added to your fine-tuned model name." )
validation_file: Optional[ validation_file: Optional[str] = (
str None # "The ID of an uploaded file that contains validation data."
] = None # "The ID of an uploaded file that contains validation data." )
integrations: Optional[ integrations: Optional[List[str]] = (
List[str] None # "A list of integrations to enable for your fine-tuning job."
] = None # "A list of integrations to enable for your fine-tuning job." )
seed: Optional[int] = None # "The seed controls the reproducibility of the job." seed: Optional[int] = None # "The seed controls the reproducibility of the job."
@ -1259,3 +1259,12 @@ class OpenAIRealtimeStreamResponseBaseObject(TypedDict):
OpenAIRealtimeStreamList = List[ OpenAIRealtimeStreamList = List[
Union[OpenAIRealtimeStreamResponseBaseObject, OpenAIRealtimeStreamSessionEvents] Union[OpenAIRealtimeStreamResponseBaseObject, OpenAIRealtimeStreamSessionEvents]
] ]
class ImageGenerationRequestQuality(str, Enum):
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
AUTO = "auto"
STANDARD = "standard"
HD = "hd"

View file

@ -66,8 +66,8 @@ class BaseImageGenTest(ABC):
logged_standard_logging_payload = custom_logger.standard_logging_payload logged_standard_logging_payload = custom_logger.standard_logging_payload
print("logged_standard_logging_payload", logged_standard_logging_payload) print("logged_standard_logging_payload", logged_standard_logging_payload)
assert logged_standard_logging_payload is not None assert logged_standard_logging_payload is not None
# assert logged_standard_logging_payload["response_cost"] is not None assert logged_standard_logging_payload["response_cost"] is not None
# assert logged_standard_logging_payload["response_cost"] > 0 assert logged_standard_logging_payload["response_cost"] > 0
from openai.types.images_response import ImagesResponse from openai.types.images_response import ImagesResponse
@ -85,4 +85,4 @@ class BaseImageGenTest(ABC):
if "Your task failed as a result of our safety system." in str(e): if "Your task failed as a result of our safety system." in str(e):
pass pass
else: else:
pytest.fail(f"An exception occurred - {str(e)}") pytest.fail(f"An exception occurred - {str(e)}")

View file

@ -161,6 +161,9 @@ class TestOpenAIDalle3(BaseImageGenTest):
def get_base_image_generation_call_args(self) -> dict: def get_base_image_generation_call_args(self) -> dict:
return {"model": "dall-e-3"} return {"model": "dall-e-3"}
class TestOpenAIGPTImage1(BaseImageGenTest):
def get_base_image_generation_call_args(self) -> dict:
return {"model": "gpt-image-1"}
class TestAzureOpenAIDalle3(BaseImageGenTest): class TestAzureOpenAIDalle3(BaseImageGenTest):
def get_base_image_generation_call_args(self) -> dict: def get_base_image_generation_call_args(self) -> dict: