forked from phoenix/litellm-mirror
fix(bedrock.py): fix response format for bedrock image generation response
Fixes https://github.com/BerriAI/litellm/issues/5010
This commit is contained in:
parent
1d56c2f83e
commit
c982ec88d8
3 changed files with 24 additions and 14 deletions
|
@ -13,6 +13,7 @@ from enum import Enum
|
||||||
from typing import Any, Callable, List, Optional, Union
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from openai.types.image import Image
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||||
|
@ -1413,10 +1414,10 @@ def embedding(
|
||||||
def image_generation(
|
def image_generation(
|
||||||
model: str,
|
model: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
model_response: ImageResponse,
|
||||||
|
optional_params: dict,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
logging_obj=None,
|
logging_obj=None,
|
||||||
model_response=None,
|
|
||||||
optional_params=None,
|
|
||||||
aimg_generation=False,
|
aimg_generation=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -1513,9 +1514,10 @@ def image_generation(
|
||||||
if model_response is None:
|
if model_response is None:
|
||||||
model_response = ImageResponse()
|
model_response = ImageResponse()
|
||||||
|
|
||||||
image_list: List = []
|
image_list: List[Image] = []
|
||||||
for artifact in response_body["artifacts"]:
|
for artifact in response_body["artifacts"]:
|
||||||
image_dict = {"url": artifact["base64"]}
|
_image = Image(b64_json=artifact["base64"])
|
||||||
|
image_list.append(_image)
|
||||||
|
|
||||||
model_response.data = image_dict
|
model_response.data = image_list
|
||||||
return model_response
|
return model_response
|
||||||
|
|
|
@ -158,7 +158,11 @@ def test_image_generation_bedrock():
|
||||||
model="bedrock/stability.stable-diffusion-xl-v1",
|
model="bedrock/stability.stable-diffusion-xl-v1",
|
||||||
aws_region_name="us-west-2",
|
aws_region_name="us-west-2",
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
|
from openai.types.images_response import ImagesResponse
|
||||||
|
|
||||||
|
ImagesResponse.model_validate(response.model_dump())
|
||||||
except litellm.RateLimitError as e:
|
except litellm.RateLimitError as e:
|
||||||
pass
|
pass
|
||||||
except litellm.ContentPolicyViolationError:
|
except litellm.ContentPolicyViolationError:
|
||||||
|
|
|
@ -950,16 +950,18 @@ class ImageObject(OpenAIObject):
|
||||||
return self.dict()
|
return self.dict()
|
||||||
|
|
||||||
|
|
||||||
class ImageResponse(OpenAIObject):
|
from openai.types.images_response import ImagesResponse as OpenAIImageResponse
|
||||||
created: Optional[int] = None
|
|
||||||
|
|
||||||
data: Optional[List[ImageObject]] = None
|
|
||||||
|
|
||||||
usage: Optional[dict] = None
|
|
||||||
|
|
||||||
|
class ImageResponse(OpenAIImageResponse):
|
||||||
_hidden_params: dict = {}
|
_hidden_params: dict = {}
|
||||||
|
|
||||||
def __init__(self, created=None, data=None, response_ms=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
created: Optional[int] = None,
|
||||||
|
data: Optional[list] = None,
|
||||||
|
response_ms=None,
|
||||||
|
):
|
||||||
if response_ms:
|
if response_ms:
|
||||||
_response_ms = response_ms
|
_response_ms = response_ms
|
||||||
else:
|
else:
|
||||||
|
@ -967,14 +969,16 @@ class ImageResponse(OpenAIObject):
|
||||||
if data:
|
if data:
|
||||||
data = data
|
data = data
|
||||||
else:
|
else:
|
||||||
data = None
|
data = []
|
||||||
|
|
||||||
if created:
|
if created:
|
||||||
created = created
|
created = created
|
||||||
else:
|
else:
|
||||||
created = None
|
created = int(time.time())
|
||||||
|
|
||||||
super().__init__(data=data, created=created)
|
_data = {"data": data, "created": created}
|
||||||
|
|
||||||
|
super().__init__(**_data)
|
||||||
self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue