mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
TestOpenAIGPTImage1
This commit is contained in:
parent
36ee132514
commit
233d1d5303
5 changed files with 59 additions and 38 deletions
|
@ -57,6 +57,7 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
|
|||
from litellm.responses.utils import ResponseAPILoggingUtils
|
||||
from litellm.types.llms.openai import (
|
||||
HttpxBinaryResponseContent,
|
||||
ImageGenerationRequestQuality,
|
||||
OpenAIRealtimeStreamList,
|
||||
OpenAIRealtimeStreamResponseBaseObject,
|
||||
OpenAIRealtimeStreamSessionEvents,
|
||||
|
@ -642,9 +643,9 @@ def completion_cost( # noqa: PLR0915
|
|||
or isinstance(completion_response, dict)
|
||||
): # tts returns a custom class
|
||||
if isinstance(completion_response, dict):
|
||||
usage_obj: Optional[
|
||||
Union[dict, Usage]
|
||||
] = completion_response.get("usage", {})
|
||||
usage_obj: Optional[Union[dict, Usage]] = (
|
||||
completion_response.get("usage", {})
|
||||
)
|
||||
else:
|
||||
usage_obj = getattr(completion_response, "usage", {})
|
||||
if isinstance(usage_obj, BaseModel) and not _is_known_usage_objects(
|
||||
|
@ -913,7 +914,7 @@ def completion_cost( # noqa: PLR0915
|
|||
|
||||
|
||||
def get_response_cost_from_hidden_params(
|
||||
hidden_params: Union[dict, BaseModel]
|
||||
hidden_params: Union[dict, BaseModel],
|
||||
) -> Optional[float]:
|
||||
if isinstance(hidden_params, BaseModel):
|
||||
_hidden_params_dict = hidden_params.model_dump()
|
||||
|
@ -1101,6 +1102,11 @@ def default_image_cost_calculator(
|
|||
f"{quality}/{base_model_name}" if quality else base_model_name
|
||||
)
|
||||
|
||||
# gpt-image-1 models use low, medium, high quality. If user did not specify quality, use medium fot gpt-image-1 model family
|
||||
model_name_with_v2_quality = (
|
||||
f"{ImageGenerationRequestQuality.MEDIUM}/{base_model_name}"
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Looking up cost for models: {model_name_with_quality}, {base_model_name}"
|
||||
)
|
||||
|
@ -1110,6 +1116,8 @@ def default_image_cost_calculator(
|
|||
cost_info = litellm.model_cost[model_name_with_quality]
|
||||
elif base_model_name in litellm.model_cost:
|
||||
cost_info = litellm.model_cost[base_model_name]
|
||||
elif model_name_with_v2_quality in litellm.model_cost:
|
||||
cost_info = litellm.model_cost[model_name_with_v2_quality]
|
||||
else:
|
||||
# Try without provider prefix
|
||||
model_without_provider = f"{size_str}/{model.split('/')[-1]}"
|
||||
|
|
|
@ -182,6 +182,7 @@ from .types.llms.openai import (
|
|||
ChatCompletionPredictionContentParam,
|
||||
ChatCompletionUserMessage,
|
||||
HttpxBinaryResponseContent,
|
||||
ImageGenerationRequestQuality,
|
||||
)
|
||||
from .types.utils import (
|
||||
LITELLM_IMAGE_VARIATION_PROVIDERS,
|
||||
|
@ -2688,9 +2689,9 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
"aws_region_name" not in optional_params
|
||||
or optional_params["aws_region_name"] is None
|
||||
):
|
||||
optional_params[
|
||||
"aws_region_name"
|
||||
] = aws_bedrock_client.meta.region_name
|
||||
optional_params["aws_region_name"] = (
|
||||
aws_bedrock_client.meta.region_name
|
||||
)
|
||||
|
||||
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
|
||||
if bedrock_route == "converse":
|
||||
|
@ -4412,9 +4413,9 @@ def adapter_completion(
|
|||
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
|
||||
|
||||
response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore
|
||||
translated_response: Optional[
|
||||
Union[BaseModel, AdapterCompletionStreamWrapper]
|
||||
] = None
|
||||
translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
|
||||
None
|
||||
)
|
||||
if isinstance(response, ModelResponse):
|
||||
translated_response = translation_obj.translate_completion_output_params(
|
||||
response=response
|
||||
|
@ -4567,7 +4568,7 @@ def image_generation( # noqa: PLR0915
|
|||
prompt: str,
|
||||
model: Optional[str] = None,
|
||||
n: Optional[int] = None,
|
||||
quality: Optional[str] = None,
|
||||
quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
|
||||
response_format: Optional[str] = None,
|
||||
size: Optional[str] = None,
|
||||
style: Optional[str] = None,
|
||||
|
@ -5834,9 +5835,9 @@ def stream_chunk_builder( # noqa: PLR0915
|
|||
]
|
||||
|
||||
if len(content_chunks) > 0:
|
||||
response["choices"][0]["message"][
|
||||
"content"
|
||||
] = processor.get_combined_content(content_chunks)
|
||||
response["choices"][0]["message"]["content"] = (
|
||||
processor.get_combined_content(content_chunks)
|
||||
)
|
||||
|
||||
reasoning_chunks = [
|
||||
chunk
|
||||
|
@ -5847,9 +5848,9 @@ def stream_chunk_builder( # noqa: PLR0915
|
|||
]
|
||||
|
||||
if len(reasoning_chunks) > 0:
|
||||
response["choices"][0]["message"][
|
||||
"reasoning_content"
|
||||
] = processor.get_combined_reasoning_content(reasoning_chunks)
|
||||
response["choices"][0]["message"]["reasoning_content"] = (
|
||||
processor.get_combined_reasoning_content(reasoning_chunks)
|
||||
)
|
||||
|
||||
audio_chunks = [
|
||||
chunk
|
||||
|
|
|
@ -824,12 +824,12 @@ class OpenAIChatCompletionChunk(ChatCompletionChunk):
|
|||
|
||||
class Hyperparameters(BaseModel):
|
||||
batch_size: Optional[Union[str, int]] = None # "Number of examples in each batch."
|
||||
learning_rate_multiplier: Optional[
|
||||
Union[str, float]
|
||||
] = None # Scaling factor for the learning rate
|
||||
n_epochs: Optional[
|
||||
Union[str, int]
|
||||
] = None # "The number of epochs to train the model for"
|
||||
learning_rate_multiplier: Optional[Union[str, float]] = (
|
||||
None # Scaling factor for the learning rate
|
||||
)
|
||||
n_epochs: Optional[Union[str, int]] = (
|
||||
None # "The number of epochs to train the model for"
|
||||
)
|
||||
|
||||
|
||||
class FineTuningJobCreate(BaseModel):
|
||||
|
@ -856,18 +856,18 @@ class FineTuningJobCreate(BaseModel):
|
|||
|
||||
model: str # "The name of the model to fine-tune."
|
||||
training_file: str # "The ID of an uploaded file that contains training data."
|
||||
hyperparameters: Optional[
|
||||
Hyperparameters
|
||||
] = None # "The hyperparameters used for the fine-tuning job."
|
||||
suffix: Optional[
|
||||
str
|
||||
] = None # "A string of up to 18 characters that will be added to your fine-tuned model name."
|
||||
validation_file: Optional[
|
||||
str
|
||||
] = None # "The ID of an uploaded file that contains validation data."
|
||||
integrations: Optional[
|
||||
List[str]
|
||||
] = None # "A list of integrations to enable for your fine-tuning job."
|
||||
hyperparameters: Optional[Hyperparameters] = (
|
||||
None # "The hyperparameters used for the fine-tuning job."
|
||||
)
|
||||
suffix: Optional[str] = (
|
||||
None # "A string of up to 18 characters that will be added to your fine-tuned model name."
|
||||
)
|
||||
validation_file: Optional[str] = (
|
||||
None # "The ID of an uploaded file that contains validation data."
|
||||
)
|
||||
integrations: Optional[List[str]] = (
|
||||
None # "A list of integrations to enable for your fine-tuning job."
|
||||
)
|
||||
seed: Optional[int] = None # "The seed controls the reproducibility of the job."
|
||||
|
||||
|
||||
|
@ -1259,3 +1259,12 @@ class OpenAIRealtimeStreamResponseBaseObject(TypedDict):
|
|||
OpenAIRealtimeStreamList = List[
|
||||
Union[OpenAIRealtimeStreamResponseBaseObject, OpenAIRealtimeStreamSessionEvents]
|
||||
]
|
||||
|
||||
|
||||
class ImageGenerationRequestQuality(str, Enum):
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
AUTO = "auto"
|
||||
STANDARD = "standard"
|
||||
HD = "hd"
|
||||
|
|
|
@ -66,8 +66,8 @@ class BaseImageGenTest(ABC):
|
|||
logged_standard_logging_payload = custom_logger.standard_logging_payload
|
||||
print("logged_standard_logging_payload", logged_standard_logging_payload)
|
||||
assert logged_standard_logging_payload is not None
|
||||
# assert logged_standard_logging_payload["response_cost"] is not None
|
||||
# assert logged_standard_logging_payload["response_cost"] > 0
|
||||
assert logged_standard_logging_payload["response_cost"] is not None
|
||||
assert logged_standard_logging_payload["response_cost"] > 0
|
||||
|
||||
from openai.types.images_response import ImagesResponse
|
||||
|
||||
|
@ -85,4 +85,4 @@ class BaseImageGenTest(ABC):
|
|||
if "Your task failed as a result of our safety system." in str(e):
|
||||
pass
|
||||
else:
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
|
@ -161,6 +161,9 @@ class TestOpenAIDalle3(BaseImageGenTest):
|
|||
def get_base_image_generation_call_args(self) -> dict:
|
||||
return {"model": "dall-e-3"}
|
||||
|
||||
class TestOpenAIGPTImage1(BaseImageGenTest):
|
||||
def get_base_image_generation_call_args(self) -> dict:
|
||||
return {"model": "gpt-image-1"}
|
||||
|
||||
class TestAzureOpenAIDalle3(BaseImageGenTest):
|
||||
def get_base_image_generation_call_args(self) -> dict:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue