Merge pull request #3739 from BerriAI/litellm_add_imagen_support

[FEAT] Async VertexAI Image Generation
This commit is contained in:
Ishaan Jaff 2024-05-20 14:14:43 -07:00 committed by GitHub
commit 91a89eb4ed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 386 additions and 1 deletions

View file

@ -965,10 +965,54 @@ class TextCompletionResponse(OpenAIObject):
setattr(self, key, value)
class ImageObject(OpenAIObject):
"""
Represents the url or the content of an image generated by the OpenAI API.
Attributes:
b64_json: The base64-encoded JSON of the generated image, if response_format is b64_json.
url: The URL of the generated image, if response_format is url (default).
revised_prompt: The prompt that was used to generate the image, if there was any revision to the prompt.
https://platform.openai.com/docs/api-reference/images/object
"""
b64_json: Optional[str] = None
url: Optional[str] = None
revised_prompt: Optional[str] = None
def __init__(self, b64_json=None, url=None, revised_prompt=None):
super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt)
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)
def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)
def __setitem__(self, key, value):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)
def json(self, **kwargs):
try:
return self.model_dump() # noqa
except:
# if using pydantic v1
return self.dict()
class ImageResponse(OpenAIObject):
created: Optional[int] = None
data: Optional[list] = None
data: Optional[List[ImageObject]] = None
usage: Optional[dict] = None
@ -4902,6 +4946,14 @@ def get_optional_params_image_gen(
width, height = size.split("x")
optional_params["width"] = int(width)
optional_params["height"] = int(height)
elif custom_llm_provider == "vertex_ai":
supported_params = ["n"]
"""
All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
"""
_check_valid_arg(supported_params=supported_params)
if n is not None:
optional_params["sampleCount"] = int(n)
for k in passed_params.keys():
if k not in default_params.keys():