diff --git a/docs/my-website/docs/image_generation.md b/docs/my-website/docs/image_generation.md
index 002d95c03..7bb4d2c99 100644
--- a/docs/my-website/docs/image_generation.md
+++ b/docs/my-website/docs/image_generation.md
@@ -150,4 +150,20 @@ response = image_generation(
     model="bedrock/stability.stable-diffusion-xl-v0",
 )
 print(f"response: {response}")
+```
+
+## VertexAI - Image Generation Models
+
+### Usage
+
+Use this for image generation models on VertexAI.
+
+```python
+response = litellm.image_generation(
+    prompt="An olympic size swimming pool",
+    model="vertex_ai/imagegeneration@006",
+    vertex_ai_project="adroit-crow-413218",
+    vertex_ai_location="us-central1",
+)
+print(f"response: {response}")
 ```
\ No newline at end of file
diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md
index b67eb350b..32c3ea188 100644
--- a/docs/my-website/docs/providers/vertex.md
+++ b/docs/my-website/docs/providers/vertex.md
@@ -508,6 +508,31 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
 | text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
 | text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
+## Image Generation Models
+
+### Usage
+
+```python
+response = await litellm.aimage_generation(
+    prompt="An olympic size swimming pool",
+    model="vertex_ai/imagegeneration@006",
+    vertex_ai_project="adroit-crow-413218",
+    vertex_ai_location="us-central1",
+)
+```
+
+**Generating multiple images**
+
+Use the `n` parameter to specify how many images you want generated.
+```python
+response = await litellm.aimage_generation(
+    prompt="An olympic size swimming pool",
+    model="vertex_ai/imagegeneration@006",
+    vertex_ai_project="adroit-crow-413218",
+    vertex_ai_location="us-central1",
+    n=1,
+)
+```
 
 ## Extra
diff --git a/litellm/__init__.py b/litellm/__init__.py
index ac2b420d7..92610afd9 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -724,6 +724,9 @@ from .utils import (
     get_supported_openai_params,
     get_api_base,
     get_first_chars_messages,
+    ModelResponse,
+    ImageResponse,
+    ImageObject,
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig
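Note on the `litellm/__init__.py` change: re-exporting `ModelResponse`, `ImageResponse`, and `ImageObject` lets callers work with image results against the top-level package instead of `litellm.utils`. A minimal sketch of what that enables (the objects are constructed directly here and the base64 payload is a placeholder, not a real API result):

```python
import litellm

# Illustrative only: build the response types by hand to show the new re-exports.
img = litellm.ImageObject(b64_json="BASE64_IMG_BYTES")  # placeholder payload
resp = litellm.ImageResponse()
resp.data = [img]  # mirrors what the Vertex AI handler assigns

assert isinstance(resp, litellm.ImageResponse)
assert all(isinstance(d, litellm.ImageObject) for d in resp.data)
```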
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
new file mode 100644
index 000000000..b8c698c90
--- /dev/null
+++ b/litellm/llms/vertex_httpx.py
@@ -0,0 +1,224 @@
+import os, types
+import json
+from enum import Enum
+import requests  # type: ignore
+import time
+from typing import Callable, Optional, Union, List, Any, Tuple
+from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
+import litellm, uuid
+import httpx, inspect  # type: ignore
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from .base import BaseLLM
+
+
+class VertexAIError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url=" https://cloud.google.com/vertex-ai/"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+class VertexLLM(BaseLLM):
+    def __init__(self) -> None:
+        super().__init__()
+        self.access_token: Optional[str] = None
+        self.refresh_token: Optional[str] = None
+        self._credentials: Optional[Any] = None
+        self.project_id: Optional[str] = None
+        self.async_handler: Optional[AsyncHTTPHandler] = None
+
+    def load_auth(self) -> Tuple[Any, str]:
+        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
+        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
+        import google.auth as google_auth
+
+        credentials, project_id = google_auth.default(
+            scopes=["https://www.googleapis.com/auth/cloud-platform"],
+        )
+
+        credentials.refresh(Request())
+
+        if not project_id:
+            raise ValueError("Could not resolve project_id")
+
+        if not isinstance(project_id, str):
+            raise TypeError(
+                f"Expected project_id to be a str but got {type(project_id)}"
+            )
+
+        return credentials, project_id
+
+    def refresh_auth(self, credentials: Any) -> None:
+        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
+
+        credentials.refresh(Request())
+
+    def _prepare_request(self, request: httpx.Request) -> None:
+        access_token = self._ensure_access_token()
+
+        if request.headers.get("Authorization"):
+            # already authenticated, nothing for us to do
+            return
+
+        request.headers["Authorization"] = f"Bearer {access_token}"
+
+    def _ensure_access_token(self) -> str:
+        if self.access_token is not None:
+            return self.access_token
+
+        if not self._credentials:
+            self._credentials, project_id = self.load_auth()
+            if not self.project_id:
+                self.project_id = project_id
+        else:
+            self.refresh_auth(self._credentials)
+
+        if not self._credentials.token:
+            raise RuntimeError("Could not resolve API token from the environment")
+
+        assert isinstance(self._credentials.token, str)
+        return self._credentials.token
+
+    def image_generation(
+        self,
+        prompt: str,
+        vertex_project: str,
+        vertex_location: str,
+        model: Optional[
+            str
+        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
+        client: Optional[AsyncHTTPHandler] = None,
+        optional_params: Optional[dict] = None,
+        timeout: Optional[int] = None,
+        logging_obj=None,
+        model_response=None,
+        aimg_generation=False,
+    ):
+        if aimg_generation == True:
+            response = self.aimage_generation(
+                prompt=prompt,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+                model=model,
+                client=client,
+                optional_params=optional_params,
+                timeout=timeout,
+                logging_obj=logging_obj,
+                model_response=model_response,
+            )
+            return response
+
+    async def aimage_generation(
+        self,
+        prompt: str,
+        vertex_project: str,
+        vertex_location: str,
+        model_response: litellm.ImageResponse,
+        model: Optional[
+            str
+        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
+        client: Optional[AsyncHTTPHandler] = None,
+        optional_params: Optional[dict] = None,
+        timeout: Optional[int] = None,
+        logging_obj=None,
+    ):
+        response = None
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            self.async_handler = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            self.async_handler = client  # type: ignore
+
+        # make POST request to
+        # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+        """
+        Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
+        curl -X POST \
+        -H "Authorization: Bearer $(gcloud auth print-access-token)" \
+        -H "Content-Type: application/json; charset=utf-8" \
+        -d {
+            "instances": [
+                {
+                    "prompt": "a cat"
+                }
+            ],
+            "parameters": {
+                "sampleCount": 1
+            }
+        } \
+        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
+        """
+        auth_header = self._ensure_access_token()
+        optional_params = optional_params or {
+            "sampleCount": 1
+        }  # default optional params
+
+        request_data = {
+            "instances": [{"prompt": prompt}],
+            "parameters": optional_params,
+        }
+
+        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+        logging_obj.pre_call(
+            input=prompt,
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        response = await self.async_handler.post(
+            url=url,
+            headers={
+                "Content-Type": "application/json; charset=utf-8",
+                "Authorization": f"Bearer {auth_header}",
+            },
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+        """
+        Vertex AI Image generation response example:
+        {
+            "predictions": [
+                {
+                    "bytesBase64Encoded": "BASE64_IMG_BYTES",
+                    "mimeType": "image/png"
+                },
+                {
+                    "mimeType": "image/png",
+                    "bytesBase64Encoded": "BASE64_IMG_BYTES"
+                }
+            ]
+        }
+        """
+
+        _json_response = response.json()
+        _predictions = _json_response["predictions"]
+
+        _response_data: List[litellm.ImageObject] = []
+        for _prediction in _predictions:
+            _bytes_base64_encoded = _prediction["bytesBase64Encoded"]
+            image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded)
+            _response_data.append(image_object)
+
+        model_response.data = _response_data
+
+        return model_response
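For readers skimming the new handler, this is the request/response shape it speaks, pulled out of the docstrings above into a standalone sketch (no network call is made; the base64 value is a placeholder):

```python
import json

# Payload that aimage_generation() POSTs to .../models/{model}:predict
request_data = {
    "instances": [{"prompt": "An olympic size swimming pool"}],
    "parameters": {"sampleCount": 1},  # litellm's `n` maps to sampleCount
}
print(json.dumps(request_data, indent=2))

# Parsing a response shaped like the example in the docstring
_json_response = {
    "predictions": [
        {"bytesBase64Encoded": "BASE64_IMG_BYTES", "mimeType": "image/png"}
    ]
}
b64_images = [p["bytesBase64Encoded"] for p in _json_response["predictions"]]
```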
diff --git a/litellm/main.py b/litellm/main.py
index 14fd5439f..7601d98a2 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -79,6 +79,7 @@ from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.predibase import PredibaseChatCompletion
 from .llms.bedrock_httpx import BedrockLLM
+from .llms.vertex_httpx import VertexLLM
 from .llms.triton import TritonChatCompletion
 from .llms.prompt_templates.factory import (
     prompt_factory,
@@ -118,6 +119,7 @@ huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
 triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
+vertex_chat_completion = VertexLLM()
 
 ####### COMPLETION ENDPOINTS ################
 
@@ -3854,6 +3856,36 @@ def image_generation(
                 model_response=model_response,
                 aimg_generation=aimg_generation,
             )
+        elif custom_llm_provider == "vertex_ai":
+            vertex_ai_project = (
+                optional_params.pop("vertex_project", None)
+                or optional_params.pop("vertex_ai_project", None)
+                or litellm.vertex_project
+                or get_secret("VERTEXAI_PROJECT")
+            )
+            vertex_ai_location = (
+                optional_params.pop("vertex_location", None)
+                or optional_params.pop("vertex_ai_location", None)
+                or litellm.vertex_location
+                or get_secret("VERTEXAI_LOCATION")
+            )
+            vertex_credentials = (
+                optional_params.pop("vertex_credentials", None)
+                or optional_params.pop("vertex_ai_credentials", None)
+                or get_secret("VERTEXAI_CREDENTIALS")
+            )
+            model_response = vertex_chat_completion.image_generation(
+                model=model,
+                prompt=prompt,
+                timeout=timeout,
+                logging_obj=litellm_logging_obj,
+                optional_params=optional_params,
+                model_response=model_response,
+                vertex_project=vertex_ai_project,
+                vertex_location=vertex_ai_location,
+                aimg_generation=aimg_generation,
+            )
+            return model_response
     except Exception as e:
         ## Map to OpenAI Exception
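Because `main.py` falls back to `litellm.vertex_project` / `litellm.vertex_location` and the `VERTEXAI_PROJECT` / `VERTEXAI_LOCATION` secrets, callers do not have to pass the project and location on every call. A minimal sketch, assuming application-default credentials are already configured (the project ID below is a placeholder):

```python
import asyncio
import os
import litellm

# Resolved by the new branch in image_generation() when the kwargs are omitted.
os.environ["VERTEXAI_PROJECT"] = "my-gcp-project"  # placeholder project ID
os.environ["VERTEXAI_LOCATION"] = "us-central1"

async def main():
    response = await litellm.aimage_generation(
        prompt="An olympic size swimming pool",
        model="vertex_ai/imagegeneration@006",
    )
    print(f"response: {response}")

asyncio.run(main())
```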
diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py
index 82068a115..35f66ad47 100644
--- a/litellm/tests/test_image_generation.py
+++ b/litellm/tests/test_image_generation.py
@@ -169,3 +169,36 @@ async def test_aimage_generation_bedrock_with_optional_params():
         pass
     else:
         pytest.fail(f"An exception occurred - {str(e)}")
+
+
+@pytest.mark.asyncio
+async def test_aimage_generation_vertex_ai():
+    from test_amazing_vertex_completion import load_vertex_ai_credentials
+
+    litellm.set_verbose = True
+
+    load_vertex_ai_credentials()
+    try:
+        response = await litellm.aimage_generation(
+            prompt="An olympic size swimming pool",
+            model="vertex_ai/imagegeneration@006",
+            vertex_ai_project="adroit-crow-413218",
+            vertex_ai_location="us-central1",
+            n=1,
+        )
+        assert response.data is not None
+        assert len(response.data) > 0
+
+        for d in response.data:
+            assert isinstance(d, litellm.ImageObject)
+            print("data in response.data", d)
+            assert d.b64_json is not None
+    except litellm.RateLimitError as e:
+        pass
+    except litellm.ContentPolicyViolationError:
+        pass  # Azure randomly raises these errors - skip when they occur
+    except Exception as e:
+        if "Your task failed as a result of our safety system." in str(e):
+            pass
+        else:
+            pytest.fail(f"An exception occurred - {str(e)}")
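A possible follow-up to the assertions in the test above: since Vertex AI returns base64-encoded PNG bytes in `bytesBase64Encoded`, a caller can persist the result with only the standard library. A hypothetical helper (the function name and output path are illustrative):

```python
import base64

def save_first_image(response, path="generated.png"):
    # response.data[0].b64_json holds the base64-encoded PNG bytes from Vertex AI
    b64_payload = response.data[0].b64_json
    with open(path, "wb") as f:
        f.write(base64.b64decode(b64_payload))
```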
diff --git a/litellm/utils.py b/litellm/utils.py
index ac246fca6..8ac6b58d8 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -965,10 +965,54 @@ class TextCompletionResponse(OpenAIObject):
         setattr(self, key, value)
 
 
+class ImageObject(OpenAIObject):
+    """
+    Represents the url or the content of an image generated by the OpenAI API.
+
+    Attributes:
+        b64_json: The base64-encoded JSON of the generated image, if response_format is b64_json.
+        url: The URL of the generated image, if response_format is url (default).
+        revised_prompt: The prompt that was used to generate the image, if there was any revision to the prompt.
+
+    https://platform.openai.com/docs/api-reference/images/object
+    """
+
+    b64_json: Optional[str] = None
+    url: Optional[str] = None
+    revised_prompt: Optional[str] = None
+
+    def __init__(self, b64_json=None, url=None, revised_prompt=None):
+        super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt)
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+
 class ImageResponse(OpenAIObject):
     created: Optional[int] = None
 
-    data: Optional[list] = None
+    data: Optional[List[ImageObject]] = None
 
     usage: Optional[dict] = None
 
@@ -4902,6 +4946,14 @@ def get_optional_params_image_gen(
             width, height = size.split("x")
             optional_params["width"] = int(width)
             optional_params["height"] = int(height)
+    elif custom_llm_provider == "vertex_ai":
+        supported_params = ["n"]
+        """
+        All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
+        """
+        _check_valid_arg(supported_params=supported_params)
+        if n is not None:
+            optional_params["sampleCount"] = int(n)
 
     for k in passed_params.keys():
         if k not in default_params.keys():
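Like the other OpenAI-compatible response objects in `litellm/utils.py`, the new `ImageObject` supports attribute access, dictionary-style access, and serialization. A small illustration (the payload is a placeholder, not a real image):

```python
from litellm import ImageObject

img = ImageObject(b64_json="BASE64_IMG_BYTES")

print(img.b64_json)               # attribute access
print(img["b64_json"])            # dictionary-style access via __getitem__
print(img.get("url", "no url"))   # .get() with a default
print("revised_prompt" in img)    # __contains__ -> True (attribute exists, value may be None)
print(img.json())                 # dict via model_dump() (pydantic v2) or dict() (v1)
```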