fix(bedrock.py): fix response format for bedrock image generation response

Fixes https://github.com/BerriAI/litellm/issues/5010
This commit is contained in:
Krrish Dholakia 2024-08-03 09:46:44 -07:00
parent 1d56c2f83e
commit c982ec88d8
3 changed files with 24 additions and 14 deletions

View file

@@ -13,6 +13,7 @@ from enum import Enum
 from typing import Any, Callable, List, Optional, Union
 import httpx
+from openai.types.image import Image
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
@@ -1413,10 +1414,10 @@ def embedding(
 def image_generation(
     model: str,
     prompt: str,
+    model_response: ImageResponse,
+    optional_params: dict,
     timeout=None,
     logging_obj=None,
-    model_response=None,
-    optional_params=None,
     aimg_generation=False,
 ):
     """
@@ -1513,9 +1514,10 @@ def image_generation(
     if model_response is None:
         model_response = ImageResponse()
-    image_list: List = []
+    image_list: List[Image] = []
     for artifact in response_body["artifacts"]:
-        image_dict = {"url": artifact["base64"]}
-    model_response.data = image_dict
+        _image = Image(b64_json=artifact["base64"])
+        image_list.append(_image)
+    model_response.data = image_list
     return model_response

View file

@@ -158,7 +158,11 @@ def test_image_generation_bedrock():
             model="bedrock/stability.stable-diffusion-xl-v1",
             aws_region_name="us-west-2",
         )
         print(f"response: {response}")
+        from openai.types.images_response import ImagesResponse
+
+        ImagesResponse.model_validate(response.model_dump())
     except litellm.RateLimitError as e:
         pass
     except litellm.ContentPolicyViolationError:

View file

@@ -950,16 +950,18 @@ class ImageObject(OpenAIObject):
         return self.dict()

-class ImageResponse(OpenAIObject):
-    created: Optional[int] = None
-    data: Optional[List[ImageObject]] = None
-    usage: Optional[dict] = None
+from openai.types.images_response import ImagesResponse as OpenAIImageResponse
+
+class ImageResponse(OpenAIImageResponse):
     _hidden_params: dict = {}

-    def __init__(self, created=None, data=None, response_ms=None):
+    def __init__(
+        self,
+        created: Optional[int] = None,
+        data: Optional[list] = None,
+        response_ms=None,
+    ):
         if response_ms:
             _response_ms = response_ms
         else:
@@ -967,14 +969,16 @@ class ImageResponse(OpenAIObject):
         if data:
             data = data
         else:
-            data = None
+            data = []
         if created:
             created = created
         else:
-            created = None
-        super().__init__(data=data, created=created)
+            created = int(time.time())
+        _data = {"data": data, "created": created}
+        super().__init__(**_data)
         self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

     def __contains__(self, key):