forked from phoenix/litellm-mirror
add ImageObject
This commit is contained in:
parent
a4f906b464
commit
2519879e67
5 changed files with 111 additions and 4 deletions
|
@ -726,6 +726,7 @@ from .utils import (
|
|||
get_first_chars_messages,
|
||||
ModelResponse,
|
||||
ImageResponse,
|
||||
ImageObject,
|
||||
)
|
||||
from .llms.huggingface_restapi import HuggingfaceConfig
|
||||
from .llms.anthropic import AnthropicConfig
|
||||
|
|
|
@ -88,7 +88,7 @@ class VertexLLM(BaseLLM):
|
|||
assert isinstance(self._credentials.token, str)
|
||||
return self._credentials.token
|
||||
|
||||
async def aimage_generation(
|
||||
def image_generation(
|
||||
self,
|
||||
prompt: str,
|
||||
vertex_project: str,
|
||||
|
@ -101,6 +101,35 @@ class VertexLLM(BaseLLM):
|
|||
timeout: Optional[int] = None,
|
||||
logging_obj=None,
|
||||
model_response=None,
|
||||
aimg_generation=False,
|
||||
):
|
||||
if aimg_generation == True:
|
||||
response = self.aimage_generation(
|
||||
prompt=prompt,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
model=model,
|
||||
client=client,
|
||||
optional_params=optional_params,
|
||||
timeout=timeout,
|
||||
logging_obj=logging_obj,
|
||||
model_response=model_response,
|
||||
)
|
||||
return response
|
||||
|
||||
async def aimage_generation(
|
||||
self,
|
||||
prompt: str,
|
||||
vertex_project: str,
|
||||
vertex_location: str,
|
||||
model_response: litellm.ImageResponse,
|
||||
model: Optional[
|
||||
str
|
||||
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
timeout: Optional[int] = None,
|
||||
logging_obj=None,
|
||||
):
|
||||
response = None
|
||||
if client is None:
|
||||
|
@ -152,5 +181,31 @@ class VertexLLM(BaseLLM):
|
|||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
"""
|
||||
Vertex AI Image generation response example:
|
||||
{
|
||||
"predictions": [
|
||||
{
|
||||
"bytesBase64Encoded": "BASE64_IMG_BYTES",
|
||||
"mimeType": "image/png"
|
||||
},
|
||||
{
|
||||
"mimeType": "image/png",
|
||||
"bytesBase64Encoded": "BASE64_IMG_BYTES"
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
_json_response = response.json()
|
||||
_predictions = _json_response["predictions"]
|
||||
|
||||
_response_data: List[litellm.ImageObject] = []
|
||||
for _prediction in _predictions:
|
||||
_bytes_base64_encoded = _prediction["bytesBase64Encoded"]
|
||||
image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded)
|
||||
_response_data.append(image_object)
|
||||
|
||||
model_response.data = _response_data
|
||||
|
||||
return model_response
|
||||
|
|
|
@ -3874,7 +3874,7 @@ def image_generation(
|
|||
or optional_params.pop("vertex_ai_credentials", None)
|
||||
or get_secret("VERTEXAI_CREDENTIALS")
|
||||
)
|
||||
model_response = vertex_chat_completion.aimage_generation( # type: ignore
|
||||
model_response = vertex_chat_completion.image_generation(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
timeout=timeout,
|
||||
|
@ -3883,6 +3883,7 @@ def image_generation(
|
|||
model_response=model_response,
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_location=vertex_ai_location,
|
||||
aimg_generation=aimg_generation,
|
||||
)
|
||||
|
||||
return model_response
|
||||
|
|
|
@ -178,7 +178,13 @@ async def test_aimage_generation_vertex_ai():
|
|||
prompt="An olympic size swimming pool",
|
||||
model="vertex_ai/imagegeneration@006",
|
||||
)
|
||||
print(f"response: {response}")
|
||||
assert response.data is not None
|
||||
assert len(response.data) > 0
|
||||
|
||||
for d in response.data:
|
||||
assert isinstance(d, litellm.ImageObject)
|
||||
print("data in response.data", d)
|
||||
assert d.b64_json is not None
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except litellm.ContentPolicyViolationError:
|
||||
|
|
|
@ -965,10 +965,54 @@ class TextCompletionResponse(OpenAIObject):
|
|||
setattr(self, key, value)
|
||||
|
||||
|
||||
class ImageObject(OpenAIObject):
|
||||
"""
|
||||
Represents the url or the content of an image generated by the OpenAI API.
|
||||
|
||||
Attributes:
|
||||
b64_json: The base64-encoded JSON of the generated image, if response_format is b64_json.
|
||||
url: The URL of the generated image, if response_format is url (default).
|
||||
revised_prompt: The prompt that was used to generate the image, if there was any revision to the prompt.
|
||||
|
||||
https://platform.openai.com/docs/api-reference/images/object
|
||||
"""
|
||||
|
||||
b64_json: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
revised_prompt: Optional[str] = None
|
||||
|
||||
def __init__(self, b64_json=None, url=None, revised_prompt=None):
|
||||
|
||||
super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt)
|
||||
|
||||
def __contains__(self, key):
|
||||
# Define custom behavior for the 'in' operator
|
||||
return hasattr(self, key)
|
||||
|
||||
def get(self, key, default=None):
|
||||
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
|
||||
return getattr(self, key, default)
|
||||
|
||||
def __getitem__(self, key):
|
||||
# Allow dictionary-style access to attributes
|
||||
return getattr(self, key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# Allow dictionary-style assignment of attributes
|
||||
setattr(self, key, value)
|
||||
|
||||
def json(self, **kwargs):
|
||||
try:
|
||||
return self.model_dump() # noqa
|
||||
except:
|
||||
# if using pydantic v1
|
||||
return self.dict()
|
||||
|
||||
|
||||
class ImageResponse(OpenAIObject):
|
||||
created: Optional[int] = None
|
||||
|
||||
data: Optional[list] = None
|
||||
data: Optional[list[ImageObject]] = None
|
||||
|
||||
usage: Optional[dict] = None
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue