diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index e12b656ed..2185ec459 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -13,6 +13,7 @@ from enum import Enum from typing import Any, Callable, List, Optional, Union import httpx +from openai.types.image import Image import litellm from litellm.litellm_core_utils.core_helpers import map_finish_reason @@ -1413,10 +1414,10 @@ def embedding( def image_generation( model: str, prompt: str, + model_response: ImageResponse, + optional_params: dict, timeout=None, logging_obj=None, - model_response=None, - optional_params=None, aimg_generation=False, ): """ @@ -1513,9 +1514,10 @@ def image_generation( if model_response is None: model_response = ImageResponse() - image_list: List = [] + image_list: List[Image] = [] for artifact in response_body["artifacts"]: - image_dict = {"url": artifact["base64"]} + _image = Image(b64_json=artifact["base64"]) + image_list.append(_image) - model_response.data = image_dict + model_response.data = image_list return model_response diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index 71f2e761e..e59cf2865 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -158,7 +158,11 @@ def test_image_generation_bedrock(): model="bedrock/stability.stable-diffusion-xl-v1", aws_region_name="us-west-2", ) + print(f"response: {response}") + from openai.types.images_response import ImagesResponse + + ImagesResponse.model_validate(response.model_dump()) except litellm.RateLimitError as e: pass except litellm.ContentPolicyViolationError: diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 314370f02..e9247a71f 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -950,16 +950,18 @@ class ImageObject(OpenAIObject): return self.dict() -class ImageResponse(OpenAIObject): - created: Optional[int] = None +from openai.types.images_response import ImagesResponse as OpenAIImageResponse - data: Optional[List[ImageObject]] = None - - usage: Optional[dict] = None +class ImageResponse(OpenAIImageResponse): _hidden_params: dict = {} - def __init__(self, created=None, data=None, response_ms=None): + def __init__( + self, + created: Optional[int] = None, + data: Optional[list] = None, + response_ms=None, + ): if response_ms: _response_ms = response_ms else: @@ -967,14 +969,16 @@ class ImageResponse(OpenAIObject): if data: data = data else: - data = None + data = [] if created: created = created else: - created = None + created = int(time.time()) - super().__init__(data=data, created=created) + _data = {"data": data, "created": created} + + super().__init__(**_data) self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} def __contains__(self, key):