mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Merge pull request #3739 from BerriAI/litellm_add_imagen_support
[FEAT] Async VertexAI Image Generation
This commit is contained in:
commit
91a89eb4ed
7 changed files with 386 additions and 1 deletions
|
@ -150,4 +150,20 @@ response = image_generation(
|
||||||
model="bedrock/stability.stable-diffusion-xl-v0",
|
model="bedrock/stability.stable-diffusion-xl-v0",
|
||||||
)
|
)
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
|
```
|
||||||
|
|
||||||
|
## VertexAI - Image Generation Models
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
Use this for image generation models on VertexAI
|
||||||
|
|
||||||
|
```python
|
||||||
|
response = litellm.image_generation(
|
||||||
|
prompt="An olympic size swimming pool",
|
||||||
|
model="vertex_ai/imagegeneration@006",
|
||||||
|
vertex_ai_project="adroit-crow-413218",
|
||||||
|
vertex_ai_location="us-central1",
|
||||||
|
)
|
||||||
|
print(f"response: {response}")
|
||||||
```
|
```
|
|
@ -508,6 +508,31 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
|
||||||
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
|
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
|
||||||
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
|
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
|
||||||
|
|
||||||
|
## Image Generation Models
|
||||||
|
|
||||||
|
Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
response = await litellm.aimage_generation(
|
||||||
|
prompt="An olympic size swimming pool",
|
||||||
|
model="vertex_ai/imagegeneration@006",
|
||||||
|
vertex_ai_project="adroit-crow-413218",
|
||||||
|
vertex_ai_location="us-central1",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Generating multiple images**
|
||||||
|
|
||||||
|
Use the `n` parameter to pass how many images you want generated
|
||||||
|
```python
|
||||||
|
response = await litellm.aimage_generation(
|
||||||
|
prompt="An olympic size swimming pool",
|
||||||
|
model="vertex_ai/imagegeneration@006",
|
||||||
|
vertex_ai_project="adroit-crow-413218",
|
||||||
|
vertex_ai_location="us-central1",
|
||||||
|
n=1,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
## Extra
|
## Extra
|
||||||
|
|
||||||
|
|
|
@ -724,6 +724,9 @@ from .utils import (
|
||||||
get_supported_openai_params,
|
get_supported_openai_params,
|
||||||
get_api_base,
|
get_api_base,
|
||||||
get_first_chars_messages,
|
get_first_chars_messages,
|
||||||
|
ModelResponse,
|
||||||
|
ImageResponse,
|
||||||
|
ImageObject,
|
||||||
)
|
)
|
||||||
from .llms.huggingface_restapi import HuggingfaceConfig
|
from .llms.huggingface_restapi import HuggingfaceConfig
|
||||||
from .llms.anthropic import AnthropicConfig
|
from .llms.anthropic import AnthropicConfig
|
||||||
|
|
224
litellm/llms/vertex_httpx.py
Normal file
224
litellm/llms/vertex_httpx.py
Normal file
|
@ -0,0 +1,224 @@
|
||||||
|
import os, types
|
||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
import requests # type: ignore
|
||||||
|
import time
|
||||||
|
from typing import Callable, Optional, Union, List, Any, Tuple
|
||||||
|
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
|
||||||
|
import litellm, uuid
|
||||||
|
import httpx, inspect # type: ignore
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
|
from .base import BaseLLM
|
||||||
|
|
||||||
|
|
||||||
|
class VertexAIError(Exception):
|
||||||
|
def __init__(self, status_code, message):
|
||||||
|
self.status_code = status_code
|
||||||
|
self.message = message
|
||||||
|
self.request = httpx.Request(
|
||||||
|
method="POST", url=" https://cloud.google.com/vertex-ai/"
|
||||||
|
)
|
||||||
|
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||||
|
super().__init__(
|
||||||
|
self.message
|
||||||
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
|
class VertexLLM(BaseLLM):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.access_token: Optional[str] = None
|
||||||
|
self.refresh_token: Optional[str] = None
|
||||||
|
self._credentials: Optional[Any] = None
|
||||||
|
self.project_id: Optional[str] = None
|
||||||
|
self.async_handler: Optional[AsyncHTTPHandler] = None
|
||||||
|
|
||||||
|
def load_auth(self) -> Tuple[Any, str]:
|
||||||
|
from google.auth.transport.requests import Request # type: ignore[import-untyped]
|
||||||
|
from google.auth.credentials import Credentials # type: ignore[import-untyped]
|
||||||
|
import google.auth as google_auth
|
||||||
|
|
||||||
|
credentials, project_id = google_auth.default(
|
||||||
|
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||||
|
)
|
||||||
|
|
||||||
|
credentials.refresh(Request())
|
||||||
|
|
||||||
|
if not project_id:
|
||||||
|
raise ValueError("Could not resolve project_id")
|
||||||
|
|
||||||
|
if not isinstance(project_id, str):
|
||||||
|
raise TypeError(
|
||||||
|
f"Expected project_id to be a str but got {type(project_id)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return credentials, project_id
|
||||||
|
|
||||||
|
def refresh_auth(self, credentials: Any) -> None:
|
||||||
|
from google.auth.transport.requests import Request # type: ignore[import-untyped]
|
||||||
|
|
||||||
|
credentials.refresh(Request())
|
||||||
|
|
||||||
|
def _prepare_request(self, request: httpx.Request) -> None:
|
||||||
|
access_token = self._ensure_access_token()
|
||||||
|
|
||||||
|
if request.headers.get("Authorization"):
|
||||||
|
# already authenticated, nothing for us to do
|
||||||
|
return
|
||||||
|
|
||||||
|
request.headers["Authorization"] = f"Bearer {access_token}"
|
||||||
|
|
||||||
|
def _ensure_access_token(self) -> str:
|
||||||
|
if self.access_token is not None:
|
||||||
|
return self.access_token
|
||||||
|
|
||||||
|
if not self._credentials:
|
||||||
|
self._credentials, project_id = self.load_auth()
|
||||||
|
if not self.project_id:
|
||||||
|
self.project_id = project_id
|
||||||
|
else:
|
||||||
|
self.refresh_auth(self._credentials)
|
||||||
|
|
||||||
|
if not self._credentials.token:
|
||||||
|
raise RuntimeError("Could not resolve API token from the environment")
|
||||||
|
|
||||||
|
assert isinstance(self._credentials.token, str)
|
||||||
|
return self._credentials.token
|
||||||
|
|
||||||
|
def image_generation(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
vertex_project: str,
|
||||||
|
vertex_location: str,
|
||||||
|
model: Optional[
|
||||||
|
str
|
||||||
|
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
optional_params: Optional[dict] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
logging_obj=None,
|
||||||
|
model_response=None,
|
||||||
|
aimg_generation=False,
|
||||||
|
):
|
||||||
|
if aimg_generation == True:
|
||||||
|
response = self.aimage_generation(
|
||||||
|
prompt=prompt,
|
||||||
|
vertex_project=vertex_project,
|
||||||
|
vertex_location=vertex_location,
|
||||||
|
model=model,
|
||||||
|
client=client,
|
||||||
|
optional_params=optional_params,
|
||||||
|
timeout=timeout,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
model_response=model_response,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def aimage_generation(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
vertex_project: str,
|
||||||
|
vertex_location: str,
|
||||||
|
model_response: litellm.ImageResponse,
|
||||||
|
model: Optional[
|
||||||
|
str
|
||||||
|
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
optional_params: Optional[dict] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
logging_obj=None,
|
||||||
|
):
|
||||||
|
response = None
|
||||||
|
if client is None:
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
_httpx_timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = _httpx_timeout
|
||||||
|
else:
|
||||||
|
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
|
||||||
|
|
||||||
|
self.async_handler = AsyncHTTPHandler(**_params) # type: ignore
|
||||||
|
else:
|
||||||
|
self.async_handler = client # type: ignore
|
||||||
|
|
||||||
|
# make POST request to
|
||||||
|
# https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
|
||||||
|
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
|
||||||
|
|
||||||
|
"""
|
||||||
|
Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
|
||||||
|
curl -X POST \
|
||||||
|
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
|
||||||
|
-H "Content-Type: application/json; charset=utf-8" \
|
||||||
|
-d {
|
||||||
|
"instances": [
|
||||||
|
{
|
||||||
|
"prompt": "a cat"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"parameters": {
|
||||||
|
"sampleCount": 1
|
||||||
|
}
|
||||||
|
} \
|
||||||
|
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
|
||||||
|
"""
|
||||||
|
auth_header = self._ensure_access_token()
|
||||||
|
optional_params = optional_params or {
|
||||||
|
"sampleCount": 1
|
||||||
|
} # default optional params
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"instances": [{"prompt": prompt}],
|
||||||
|
"parameters": optional_params,
|
||||||
|
}
|
||||||
|
|
||||||
|
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=prompt,
|
||||||
|
api_key=None,
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": optional_params,
|
||||||
|
"request_str": request_str,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await self.async_handler.post(
|
||||||
|
url=url,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json; charset=utf-8",
|
||||||
|
"Authorization": f"Bearer {auth_header}",
|
||||||
|
},
|
||||||
|
data=json.dumps(request_data),
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||||
|
"""
|
||||||
|
Vertex AI Image generation response example:
|
||||||
|
{
|
||||||
|
"predictions": [
|
||||||
|
{
|
||||||
|
"bytesBase64Encoded": "BASE64_IMG_BYTES",
|
||||||
|
"mimeType": "image/png"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"mimeType": "image/png",
|
||||||
|
"bytesBase64Encoded": "BASE64_IMG_BYTES"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
_json_response = response.json()
|
||||||
|
_predictions = _json_response["predictions"]
|
||||||
|
|
||||||
|
_response_data: List[litellm.ImageObject] = []
|
||||||
|
for _prediction in _predictions:
|
||||||
|
_bytes_base64_encoded = _prediction["bytesBase64Encoded"]
|
||||||
|
image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded)
|
||||||
|
_response_data.append(image_object)
|
||||||
|
|
||||||
|
model_response.data = _response_data
|
||||||
|
|
||||||
|
return model_response
|
|
@ -79,6 +79,7 @@ from .llms.anthropic_text import AnthropicTextCompletion
|
||||||
from .llms.huggingface_restapi import Huggingface
|
from .llms.huggingface_restapi import Huggingface
|
||||||
from .llms.predibase import PredibaseChatCompletion
|
from .llms.predibase import PredibaseChatCompletion
|
||||||
from .llms.bedrock_httpx import BedrockLLM
|
from .llms.bedrock_httpx import BedrockLLM
|
||||||
|
from .llms.vertex_httpx import VertexLLM
|
||||||
from .llms.triton import TritonChatCompletion
|
from .llms.triton import TritonChatCompletion
|
||||||
from .llms.prompt_templates.factory import (
|
from .llms.prompt_templates.factory import (
|
||||||
prompt_factory,
|
prompt_factory,
|
||||||
|
@ -118,6 +119,7 @@ huggingface = Huggingface()
|
||||||
predibase_chat_completions = PredibaseChatCompletion()
|
predibase_chat_completions = PredibaseChatCompletion()
|
||||||
triton_chat_completions = TritonChatCompletion()
|
triton_chat_completions = TritonChatCompletion()
|
||||||
bedrock_chat_completion = BedrockLLM()
|
bedrock_chat_completion = BedrockLLM()
|
||||||
|
vertex_chat_completion = VertexLLM()
|
||||||
####### COMPLETION ENDPOINTS ################
|
####### COMPLETION ENDPOINTS ################
|
||||||
|
|
||||||
|
|
||||||
|
@ -3854,6 +3856,36 @@ def image_generation(
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
aimg_generation=aimg_generation,
|
aimg_generation=aimg_generation,
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "vertex_ai":
|
||||||
|
vertex_ai_project = (
|
||||||
|
optional_params.pop("vertex_project", None)
|
||||||
|
or optional_params.pop("vertex_ai_project", None)
|
||||||
|
or litellm.vertex_project
|
||||||
|
or get_secret("VERTEXAI_PROJECT")
|
||||||
|
)
|
||||||
|
vertex_ai_location = (
|
||||||
|
optional_params.pop("vertex_location", None)
|
||||||
|
or optional_params.pop("vertex_ai_location", None)
|
||||||
|
or litellm.vertex_location
|
||||||
|
or get_secret("VERTEXAI_LOCATION")
|
||||||
|
)
|
||||||
|
vertex_credentials = (
|
||||||
|
optional_params.pop("vertex_credentials", None)
|
||||||
|
or optional_params.pop("vertex_ai_credentials", None)
|
||||||
|
or get_secret("VERTEXAI_CREDENTIALS")
|
||||||
|
)
|
||||||
|
model_response = vertex_chat_completion.image_generation(
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
timeout=timeout,
|
||||||
|
logging_obj=litellm_logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
model_response=model_response,
|
||||||
|
vertex_project=vertex_ai_project,
|
||||||
|
vertex_location=vertex_ai_location,
|
||||||
|
aimg_generation=aimg_generation,
|
||||||
|
)
|
||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
## Map to OpenAI Exception
|
## Map to OpenAI Exception
|
||||||
|
|
|
@ -169,3 +169,36 @@ async def test_aimage_generation_bedrock_with_optional_params():
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aimage_generation_vertex_ai():
|
||||||
|
from test_amazing_vertex_completion import load_vertex_ai_credentials
|
||||||
|
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
load_vertex_ai_credentials()
|
||||||
|
try:
|
||||||
|
response = await litellm.aimage_generation(
|
||||||
|
prompt="An olympic size swimming pool",
|
||||||
|
model="vertex_ai/imagegeneration@006",
|
||||||
|
vertex_ai_project="adroit-crow-413218",
|
||||||
|
vertex_ai_location="us-central1",
|
||||||
|
n=1,
|
||||||
|
)
|
||||||
|
assert response.data is not None
|
||||||
|
assert len(response.data) > 0
|
||||||
|
|
||||||
|
for d in response.data:
|
||||||
|
assert isinstance(d, litellm.ImageObject)
|
||||||
|
print("data in response.data", d)
|
||||||
|
assert d.b64_json is not None
|
||||||
|
except litellm.RateLimitError as e:
|
||||||
|
pass
|
||||||
|
except litellm.ContentPolicyViolationError:
|
||||||
|
pass # Azure randomly raises these errors - skip when they occur
|
||||||
|
except Exception as e:
|
||||||
|
if "Your task failed as a result of our safety system." in str(e):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
|
|
|
@ -965,10 +965,54 @@ class TextCompletionResponse(OpenAIObject):
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageObject(OpenAIObject):
|
||||||
|
"""
|
||||||
|
Represents the url or the content of an image generated by the OpenAI API.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
b64_json: The base64-encoded JSON of the generated image, if response_format is b64_json.
|
||||||
|
url: The URL of the generated image, if response_format is url (default).
|
||||||
|
revised_prompt: The prompt that was used to generate the image, if there was any revision to the prompt.
|
||||||
|
|
||||||
|
https://platform.openai.com/docs/api-reference/images/object
|
||||||
|
"""
|
||||||
|
|
||||||
|
b64_json: Optional[str] = None
|
||||||
|
url: Optional[str] = None
|
||||||
|
revised_prompt: Optional[str] = None
|
||||||
|
|
||||||
|
def __init__(self, b64_json=None, url=None, revised_prompt=None):
|
||||||
|
|
||||||
|
super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt)
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
# Define custom behavior for the 'in' operator
|
||||||
|
return hasattr(self, key)
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
|
||||||
|
return getattr(self, key, default)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
# Allow dictionary-style access to attributes
|
||||||
|
return getattr(self, key)
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
# Allow dictionary-style assignment of attributes
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
def json(self, **kwargs):
|
||||||
|
try:
|
||||||
|
return self.model_dump() # noqa
|
||||||
|
except:
|
||||||
|
# if using pydantic v1
|
||||||
|
return self.dict()
|
||||||
|
|
||||||
|
|
||||||
class ImageResponse(OpenAIObject):
|
class ImageResponse(OpenAIObject):
|
||||||
created: Optional[int] = None
|
created: Optional[int] = None
|
||||||
|
|
||||||
data: Optional[list] = None
|
data: Optional[List[ImageObject]] = None
|
||||||
|
|
||||||
usage: Optional[dict] = None
|
usage: Optional[dict] = None
|
||||||
|
|
||||||
|
@ -4902,6 +4946,14 @@ def get_optional_params_image_gen(
|
||||||
width, height = size.split("x")
|
width, height = size.split("x")
|
||||||
optional_params["width"] = int(width)
|
optional_params["width"] = int(width)
|
||||||
optional_params["height"] = int(height)
|
optional_params["height"] = int(height)
|
||||||
|
elif custom_llm_provider == "vertex_ai":
|
||||||
|
supported_params = ["n"]
|
||||||
|
"""
|
||||||
|
All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
|
||||||
|
"""
|
||||||
|
_check_valid_arg(supported_params=supported_params)
|
||||||
|
if n is not None:
|
||||||
|
optional_params["sampleCount"] = int(n)
|
||||||
|
|
||||||
for k in passed_params.keys():
|
for k in passed_params.keys():
|
||||||
if k not in default_params.keys():
|
if k not in default_params.keys():
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue