diff --git a/litellm/__init__.py b/litellm/__init__.py index 83e30d775..92610afd9 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -726,6 +726,7 @@ from .utils import ( get_first_chars_messages, ModelResponse, ImageResponse, + ImageObject, ) from .llms.huggingface_restapi import HuggingfaceConfig from .llms.anthropic import AnthropicConfig diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index ca850674b..0e16c02e7 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -88,7 +88,7 @@ class VertexLLM(BaseLLM): assert isinstance(self._credentials.token, str) return self._credentials.token - async def aimage_generation( + def image_generation( self, prompt: str, vertex_project: str, @@ -101,6 +101,35 @@ class VertexLLM(BaseLLM): timeout: Optional[int] = None, logging_obj=None, model_response=None, + aimg_generation=False, + ): + if aimg_generation: + response = self.aimage_generation( + prompt=prompt, + vertex_project=vertex_project, + vertex_location=vertex_location, + model=model, + client=client, + optional_params=optional_params, + timeout=timeout, + logging_obj=logging_obj, + model_response=model_response, + ) + return response + + async def aimage_generation( + self, + prompt: str, + vertex_project: str, + vertex_location: str, + model_response: litellm.ImageResponse, + model: Optional[ + str + ] = "imagegeneration", # vertex ai uses imagegeneration as the default model + client: Optional[AsyncHTTPHandler] = None, + optional_params: Optional[dict] = None, + timeout: Optional[int] = None, + logging_obj=None, ): response = None if client is None: @@ -152,5 +181,31 @@ class VertexLLM(BaseLLM): if response.status_code != 200: raise Exception(f"Error: {response.status_code} {response.text}") + """ + Vertex AI Image generation response example: + { + "predictions": [ + { + "bytesBase64Encoded": "BASE64_IMG_BYTES", + "mimeType": "image/png" + }, + { + "mimeType": "image/png", + "bytesBase64Encoded": 
"BASE64_IMG_BYTES" + } + ] + } + """ + + _json_response = response.json() + _predictions = _json_response["predictions"] + + _response_data: List[litellm.ImageObject] = [] + for _prediction in _predictions: + _bytes_base64_encoded = _prediction["bytesBase64Encoded"] + image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded) + _response_data.append(image_object) + + model_response.data = _response_data return model_response diff --git a/litellm/main.py b/litellm/main.py index 198e191fd..7601d98a2 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3874,7 +3874,7 @@ def image_generation( or optional_params.pop("vertex_ai_credentials", None) or get_secret("VERTEXAI_CREDENTIALS") ) - model_response = vertex_chat_completion.aimage_generation( # type: ignore + model_response = vertex_chat_completion.image_generation( model=model, prompt=prompt, timeout=timeout, @@ -3883,6 +3883,7 @@ def image_generation( model_response=model_response, vertex_project=vertex_ai_project, vertex_location=vertex_ai_location, + aimg_generation=aimg_generation, ) return model_response diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index 3de3ba763..886953c1a 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -178,7 +178,13 @@ async def test_aimage_generation_vertex_ai(): prompt="An olympic size swimming pool", model="vertex_ai/imagegeneration@006", ) - print(f"response: {response}") + assert response.data is not None + assert len(response.data) > 0 + + for d in response.data: + assert isinstance(d, litellm.ImageObject) + print("data in response.data", d) + assert d.b64_json is not None except litellm.RateLimitError as e: pass except litellm.ContentPolicyViolationError: diff --git a/litellm/utils.py b/litellm/utils.py index 6d0231e8f..b4a2bd618 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -965,10 +965,54 @@ class TextCompletionResponse(OpenAIObject): setattr(self, key, value) 
+class ImageObject(OpenAIObject): + """ + Represents the url or the content of an image generated by the OpenAI API. + + Attributes: + b64_json: The base64-encoded JSON of the generated image, if response_format is b64_json. + url: The URL of the generated image, if response_format is url (default). + revised_prompt: The prompt that was used to generate the image, if there was any revision to the prompt. + + https://platform.openai.com/docs/api-reference/images/object + """ + + b64_json: Optional[str] = None + url: Optional[str] = None + revised_prompt: Optional[str] = None + + def __init__(self, b64_json=None, url=None, revised_prompt=None): + + super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt) + + def __contains__(self, key): + # Define custom behavior for the 'in' operator + return hasattr(self, key) + + def get(self, key, default=None): + # Custom .get() method to access attributes with a default value if the attribute doesn't exist + return getattr(self, key, default) + + def __getitem__(self, key): + # Allow dictionary-style access to attributes + return getattr(self, key) + + def __setitem__(self, key, value): + # Allow dictionary-style assignment of attributes + setattr(self, key, value) + + def json(self, **kwargs): + try: + return self.model_dump() # noqa + except Exception: + # if using pydantic v1 + return self.dict() + + class ImageResponse(OpenAIObject): created: Optional[int] = None - data: Optional[list] = None + data: Optional[List[ImageObject]] = None usage: Optional[dict] = None