mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

TestOpenAIGPTImage1

commit 233d1d5303 (parent 36ee132514)
5 changed files with 59 additions and 38 deletions
litellm/cost_calculator.py:

@@ -57,6 +57,7 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
 from litellm.responses.utils import ResponseAPILoggingUtils
 from litellm.types.llms.openai import (
     HttpxBinaryResponseContent,
+    ImageGenerationRequestQuality,
     OpenAIRealtimeStreamList,
     OpenAIRealtimeStreamResponseBaseObject,
     OpenAIRealtimeStreamSessionEvents,
@@ -642,9 +643,9 @@ def completion_cost( # noqa: PLR0915
         or isinstance(completion_response, dict)
     ):  # tts returns a custom class
         if isinstance(completion_response, dict):
-            usage_obj: Optional[
-                Union[dict, Usage]
-            ] = completion_response.get("usage", {})
+            usage_obj: Optional[Union[dict, Usage]] = (
+                completion_response.get("usage", {})
+            )
         else:
             usage_obj = getattr(completion_response, "usage", {})
         if isinstance(usage_obj, BaseModel) and not _is_known_usage_objects(
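The hunk above is a pure Black-style re-wrap; the logic it formats extracts usage for cost calculation from either a dict or an object response. A standalone sketch of that pattern, with illustrative payloads:

from typing import Any, Union

def extract_usage(completion_response: Any) -> Union[dict, Any]:
    # dict responses (e.g. tts) carry usage under a "usage" key;
    # object responses expose it as a .usage attribute
    if isinstance(completion_response, dict):
        return completion_response.get("usage", {})
    return getattr(completion_response, "usage", {})

print(extract_usage({"usage": {"prompt_tokens": 3}}))  # {'prompt_tokens': 3}
print(extract_usage(object()))  # {}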
@@ -913,7 +914,7 @@ def completion_cost( # noqa: PLR0915


 def get_response_cost_from_hidden_params(
-    hidden_params: Union[dict, BaseModel]
+    hidden_params: Union[dict, BaseModel],
 ) -> Optional[float]:
     if isinstance(hidden_params, BaseModel):
         _hidden_params_dict = hidden_params.model_dump()
@@ -1101,6 +1102,11 @@ def default_image_cost_calculator(
         f"{quality}/{base_model_name}" if quality else base_model_name
     )

+    # gpt-image-1 models use low, medium, high quality. If the user did not specify quality, use medium for the gpt-image-1 model family
+    model_name_with_v2_quality = (
+        f"{ImageGenerationRequestQuality.MEDIUM}/{base_model_name}"
+    )
+
     verbose_logger.debug(
         f"Looking up cost for models: {model_name_with_quality}, {base_model_name}"
     )
@@ -1110,6 +1116,8 @@ def default_image_cost_calculator(
         cost_info = litellm.model_cost[model_name_with_quality]
     elif base_model_name in litellm.model_cost:
         cost_info = litellm.model_cost[base_model_name]
+    elif model_name_with_v2_quality in litellm.model_cost:
+        cost_info = litellm.model_cost[model_name_with_v2_quality]
     else:
         # Try without provider prefix
         model_without_provider = f"{size_str}/{model.split('/')[-1]}"
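Taken together, the hunk above and the one before it give the pricing lookup a medium-quality fallback for the gpt-image-1 family. A minimal sketch of the resulting lookup order, with an illustrative stand-in for litellm.model_cost (the real table and its rates live in litellm's pricing data):

from enum import Enum
from typing import Optional

class ImageGenerationRequestQuality(str, Enum):
    MEDIUM = "medium"

# illustrative pricing entries, not litellm's real table
model_cost = {"medium/gpt-image-1": {"input_cost_per_pixel": 1e-7}}

def lookup_cost_info(base_model_name: str, quality: Optional[str]) -> Optional[dict]:
    with_quality = f"{quality}/{base_model_name}" if quality else base_model_name
    # fallback key tried when the caller gave no quality
    with_v2_quality = f"{ImageGenerationRequestQuality.MEDIUM.value}/{base_model_name}"
    for key in (with_quality, base_model_name, with_v2_quality):
        if key in model_cost:
            return model_cost[key]
    return None  # the real function then retries without the provider prefix

print(lookup_cost_info("gpt-image-1", None))  # hits the medium/ fallback key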
litellm/main.py:

@@ -182,6 +182,7 @@ from .types.llms.openai import (
     ChatCompletionPredictionContentParam,
     ChatCompletionUserMessage,
     HttpxBinaryResponseContent,
+    ImageGenerationRequestQuality,
 )
 from .types.utils import (
     LITELLM_IMAGE_VARIATION_PROVIDERS,
@@ -2688,9 +2689,9 @@ def completion( # type: ignore # noqa: PLR0915
                 "aws_region_name" not in optional_params
                 or optional_params["aws_region_name"] is None
             ):
-                optional_params[
-                    "aws_region_name"
-                ] = aws_bedrock_client.meta.region_name
+                optional_params["aws_region_name"] = (
+                    aws_bedrock_client.meta.region_name
+                )

             bedrock_route = BedrockModelInfo.get_bedrock_route(model)
             if bedrock_route == "converse":
@@ -4412,9 +4413,9 @@ def adapter_completion(
     new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)

     response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs)  # type: ignore
-    translated_response: Optional[
-        Union[BaseModel, AdapterCompletionStreamWrapper]
-    ] = None
+    translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
+        None
+    )
     if isinstance(response, ModelResponse):
         translated_response = translation_obj.translate_completion_output_params(
             response=response
@@ -4567,7 +4568,7 @@ def image_generation( # noqa: PLR0915
     prompt: str,
     model: Optional[str] = None,
     n: Optional[int] = None,
-    quality: Optional[str] = None,
+    quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
     response_format: Optional[str] = None,
     size: Optional[str] = None,
     style: Optional[str] = None,
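With the widened signature, callers can pass quality either as a plain string or as the new enum. A hedged usage sketch (assumes OPENAI_API_KEY is set and the account has access to gpt-image-1):

import litellm
from litellm.types.llms.openai import ImageGenerationRequestQuality

# both forms are accepted after this change; the str mixin keeps the enum
# wire-compatible with the provider's "quality" field
response = litellm.image_generation(
    prompt="a watercolor fox",
    model="gpt-image-1",
    quality=ImageGenerationRequestQuality.MEDIUM,  # or simply "medium"
)
print(response.data[0])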
@@ -5834,9 +5835,9 @@ def stream_chunk_builder( # noqa: PLR0915
     ]

     if len(content_chunks) > 0:
-        response["choices"][0]["message"][
-            "content"
-        ] = processor.get_combined_content(content_chunks)
+        response["choices"][0]["message"]["content"] = (
+            processor.get_combined_content(content_chunks)
+        )

     reasoning_chunks = [
         chunk
@@ -5847,9 +5848,9 @@ def stream_chunk_builder( # noqa: PLR0915
     ]

     if len(reasoning_chunks) > 0:
-        response["choices"][0]["message"][
-            "reasoning_content"
-        ] = processor.get_combined_reasoning_content(reasoning_chunks)
+        response["choices"][0]["message"]["reasoning_content"] = (
+            processor.get_combined_reasoning_content(reasoning_chunks)
+        )

     audio_chunks = [
         chunk
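Both stream_chunk_builder hunks are Black re-wraps of the same assignments: the combined content and reasoning_content are written back onto the first choice's message. A minimal sketch of that combining step, with a stand-in for litellm's chunk processor:

# stand-in processor; litellm's real one handles many more delta fields
class Processor:
    def get_combined_content(self, chunks: list) -> str:
        return "".join(
            c["choices"][0]["delta"].get("content") or "" for c in chunks
        )

chunks = [
    {"choices": [{"delta": {"content": "Hel"}}]},
    {"choices": [{"delta": {"content": "lo"}}]},
]
response = {"choices": [{"message": {}}]}
content_chunks = [c for c in chunks if c["choices"][0]["delta"].get("content")]
if len(content_chunks) > 0:
    response["choices"][0]["message"]["content"] = (
        Processor().get_combined_content(content_chunks)
    )
print(response["choices"][0]["message"]["content"])  # Hello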
litellm/types/llms/openai.py:

@@ -824,12 +824,12 @@ class OpenAIChatCompletionChunk(ChatCompletionChunk):

 class Hyperparameters(BaseModel):
     batch_size: Optional[Union[str, int]] = None  # "Number of examples in each batch."
-    learning_rate_multiplier: Optional[
-        Union[str, float]
-    ] = None  # Scaling factor for the learning rate
-    n_epochs: Optional[
-        Union[str, int]
-    ] = None  # "The number of epochs to train the model for"
+    learning_rate_multiplier: Optional[Union[str, float]] = (
+        None  # Scaling factor for the learning rate
+    )
+    n_epochs: Optional[Union[str, int]] = (
+        None  # "The number of epochs to train the model for"
+    )


 class FineTuningJobCreate(BaseModel):
@@ -856,18 +856,18 @@ class FineTuningJobCreate(BaseModel):

     model: str  # "The name of the model to fine-tune."
     training_file: str  # "The ID of an uploaded file that contains training data."
-    hyperparameters: Optional[
-        Hyperparameters
-    ] = None  # "The hyperparameters used for the fine-tuning job."
-    suffix: Optional[
-        str
-    ] = None  # "A string of up to 18 characters that will be added to your fine-tuned model name."
-    validation_file: Optional[
-        str
-    ] = None  # "The ID of an uploaded file that contains validation data."
-    integrations: Optional[
-        List[str]
-    ] = None  # "A list of integrations to enable for your fine-tuning job."
+    hyperparameters: Optional[Hyperparameters] = (
+        None  # "The hyperparameters used for the fine-tuning job."
+    )
+    suffix: Optional[str] = (
+        None  # "A string of up to 18 characters that will be added to your fine-tuned model name."
+    )
+    validation_file: Optional[str] = (
+        None  # "The ID of an uploaded file that contains validation data."
+    )
+    integrations: Optional[List[str]] = (
+        None  # "A list of integrations to enable for your fine-tuning job."
+    )
     seed: Optional[int] = None  # "The seed controls the reproducibility of the job."

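These Hyperparameters/FineTuningJobCreate hunks only re-wrap the Optional annotations into Black's newer parenthesized style; field semantics are unchanged. A sketch constructing the model with placeholder values:

from litellm.types.llms.openai import FineTuningJobCreate, Hyperparameters

# file IDs and model name are placeholders
job = FineTuningJobCreate(
    model="gpt-4o-mini-2024-07-18",
    training_file="file-abc123",
    hyperparameters=Hyperparameters(n_epochs=3, batch_size="auto"),
    suffix="my-experiment",
)
print(job.model_dump(exclude_none=True))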
@@ -1259,3 +1259,12 @@ class OpenAIRealtimeStreamResponseBaseObject(TypedDict):
 OpenAIRealtimeStreamList = List[
     Union[OpenAIRealtimeStreamResponseBaseObject, OpenAIRealtimeStreamSessionEvents]
 ]
+
+
+class ImageGenerationRequestQuality(str, Enum):
+    LOW = "low"
+    MEDIUM = "medium"
+    HIGH = "high"
+    AUTO = "auto"
+    STANDARD = "standard"
+    HD = "hd"
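Because the enum mixes in str, members compare equal to plain strings and can be embedded in the quality-prefixed pricing keys. A quick sketch (.value is used to sidestep version-dependent Enum formatting):

from litellm.types.llms.openai import ImageGenerationRequestQuality

q = ImageGenerationRequestQuality.MEDIUM
print(q == "medium")  # True: the str mixin makes members string-comparable
print(f"{q.value}/gpt-image-1")  # "medium/gpt-image-1", the pricing-key shape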
Base image generation test (BaseImageGenTest):

@@ -66,8 +66,8 @@ class BaseImageGenTest(ABC):
             logged_standard_logging_payload = custom_logger.standard_logging_payload
             print("logged_standard_logging_payload", logged_standard_logging_payload)
             assert logged_standard_logging_payload is not None
-            # assert logged_standard_logging_payload["response_cost"] is not None
-            # assert logged_standard_logging_payload["response_cost"] > 0
+            assert logged_standard_logging_payload["response_cost"] is not None
+            assert logged_standard_logging_payload["response_cost"] > 0

             from openai.types.images_response import ImagesResponse

@@ -85,4 +85,4 @@ class BaseImageGenTest(ABC):
             if "Your task failed as a result of our safety system." in str(e):
                 pass
             else:
-                pytest.fail(f"An exception occurred - {str(e)}")
+                pytest.fail(f"An exception occurred - {str(e)}")
Image generation test classes:

@@ -161,6 +161,9 @@ class TestOpenAIDalle3(BaseImageGenTest):
     def get_base_image_generation_call_args(self) -> dict:
         return {"model": "dall-e-3"}

+class TestOpenAIGPTImage1(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
+        return {"model": "gpt-image-1"}

 class TestAzureOpenAIDalle3(BaseImageGenTest):
     def get_base_image_generation_call_args(self) -> dict:
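The new TestOpenAIGPTImage1 class inherits every test from BaseImageGenTest and overrides only the call args. A compressed sketch of that pattern (not the actual test file, which also wires in a custom logger for the cost assertions shown earlier):

from abc import ABC, abstractmethod

import litellm

class BaseImageGenTest(ABC):
    @abstractmethod
    def get_base_image_generation_call_args(self) -> dict: ...

    def test_basic_image_generation(self):
        # sketch: the real suite also asserts
        # logged_standard_logging_payload["response_cost"] > 0
        args = self.get_base_image_generation_call_args()
        response = litellm.image_generation(prompt="A cute baby sea otter", **args)
        assert response.data is not None

class TestOpenAIGPTImage1(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        return {"model": "gpt-image-1"}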