From 5ba5f15b56f411eabcec83b6f34f3274aa7971ff Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 08:14:43 -0700 Subject: [PATCH 01/17] test - test_aimage_generation_vertex_ai --- litellm/tests/test_image_generation.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index 82068a1156..37e560b240 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -169,3 +169,22 @@ async def test_aimage_generation_bedrock_with_optional_params(): pass else: pytest.fail(f"An exception occurred - {str(e)}") + + +@pytest.mark.asyncio +async def test_aimage_generation_vertex_ai(): + try: + response = await litellm.aimage_generation( + prompt="A cute baby sea otter", + model="vertex_ai/imagegeneration@006", + ) + print(f"response: {response}") + except litellm.RateLimitError as e: + pass + except litellm.ContentPolicyViolationError: + pass # Azure randomly raises these errors - skip when they occur + except Exception as e: + if "Your task failed as a result of our safety system." in str(e): + pass + else: + pytest.fail(f"An exception occurred - {str(e)}") From 24951d44a4d9061fd638b619918a8a6c17718b03 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 09:51:15 -0700 Subject: [PATCH 02/17] feat - working httpx requests vertex ai image gen --- litellm/llms/vertex_httpx.py | 156 +++++++++++++++++++++++++ litellm/main.py | 31 +++++ litellm/tests/test_image_generation.py | 2 +- 3 files changed, 188 insertions(+), 1 deletion(-) create mode 100644 litellm/llms/vertex_httpx.py diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py new file mode 100644 index 0000000000..ca850674b8 --- /dev/null +++ b/litellm/llms/vertex_httpx.py @@ -0,0 +1,156 @@ +import os, types +import json +from enum import Enum +import requests # type: ignore +import time +from typing import Callable, Optional, Union, List +from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason +import litellm, uuid +import httpx, inspect # type: ignore +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from .base import BaseLLM + + +class VertexAIError(Exception): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + self.request = httpx.Request( + method="POST", url=" https://cloud.google.com/vertex-ai/" + ) + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + + +class VertexLLM(BaseLLM): + from google.auth.credentials import Credentials # type: ignore[import-untyped] + + def __init__(self) -> None: + from google.auth.credentials import Credentials # type: ignore[import-untyped] + + super().__init__() + self.access_token: Optional[str] = None + self.refresh_token: Optional[str] = None + self._credentials: Optional[Credentials] = None + self.project_id: Optional[str] = None + + def load_auth(self) -> tuple[Credentials, str]: + from google.auth.transport.requests import Request # type: ignore[import-untyped] + from google.auth.credentials import Credentials # type: ignore[import-untyped] + import google.auth as google_auth + + credentials, project_id = google_auth.default( + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + + credentials.refresh(Request()) + + if not project_id: + raise ValueError("Could not resolve project_id") + + if not isinstance(project_id, str): + raise TypeError( + f"Expected project_id to be a str but got {type(project_id)}" + ) + + return credentials, project_id + + def refresh_auth(self, credentials: Credentials) -> None: + from google.auth.transport.requests import Request # type: ignore[import-untyped] + + credentials.refresh(Request()) + + def _prepare_request(self, request: httpx.Request) -> None: + access_token = self._ensure_access_token() + + if request.headers.get("Authorization"): + # already authenticated, nothing for us to do + return + + request.headers["Authorization"] = f"Bearer {access_token}" + + def _ensure_access_token(self) -> str: + if self.access_token is not None: + return self.access_token + + if not self._credentials: + self._credentials, project_id = self.load_auth() + if not self.project_id: + self.project_id = project_id + else: + self.refresh_auth(self._credentials) + + if not self._credentials.token: + raise RuntimeError("Could not resolve API token from the environment") + + assert isinstance(self._credentials.token, str) + return self._credentials.token + + async def aimage_generation( + self, + prompt: str, + vertex_project: str, + vertex_location: str, + model: Optional[ + str + ] = "imagegeneration", # vertex ai uses imagegeneration as the default model + client: Optional[AsyncHTTPHandler] = None, + optional_params: Optional[dict] = None, + timeout: Optional[int] = None, + logging_obj=None, + model_response=None, + ): + response = None + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + _httpx_timeout = httpx.Timeout(timeout) + _params["timeout"] = _httpx_timeout + client = AsyncHTTPHandler(**_params) # type: ignore + else: + client = client # type: ignore + + # make POST request to + # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict" + + """ + Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218 + curl -X POST \ + -H "Authorization: Bearer $(gcloud auth print-access-token)" \ + -H "Content-Type: application/json; charset=utf-8" \ + -d { + "instances": [ + { + "prompt": "a cat" + } + ] + } \ + "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict" + """ + + import vertexai + + auth_header = self._ensure_access_token() + + request_data = { + "instances": [{"prompt": prompt}], + "parameters": {"sampleCount": 1}, + } + + response = await client.post( + url=url, + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {auth_header}", + }, + data=json.dumps(request_data), + ) + + if response.status_code != 200: + raise Exception(f"Error: {response.status_code} {response.text}") + + return model_response diff --git a/litellm/main.py b/litellm/main.py index 14fd5439ff..198e191fd8 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -79,6 +79,7 @@ from .llms.anthropic_text import AnthropicTextCompletion from .llms.huggingface_restapi import Huggingface from .llms.predibase import PredibaseChatCompletion from .llms.bedrock_httpx import BedrockLLM +from .llms.vertex_httpx import VertexLLM from .llms.triton import TritonChatCompletion from .llms.prompt_templates.factory import ( prompt_factory, @@ -118,6 +119,7 @@ huggingface = Huggingface() predibase_chat_completions = PredibaseChatCompletion() triton_chat_completions = TritonChatCompletion() bedrock_chat_completion = BedrockLLM() +vertex_chat_completion = VertexLLM() ####### COMPLETION ENDPOINTS ################ @@ -3854,6 +3856,35 @@ def image_generation( model_response=model_response, aimg_generation=aimg_generation, ) + elif custom_llm_provider == "vertex_ai": + vertex_ai_project = ( + optional_params.pop("vertex_project", None) + or optional_params.pop("vertex_ai_project", None) + or litellm.vertex_project + or get_secret("VERTEXAI_PROJECT") + ) + vertex_ai_location = ( + optional_params.pop("vertex_location", None) + or optional_params.pop("vertex_ai_location", None) + or litellm.vertex_location + or get_secret("VERTEXAI_LOCATION") + ) + vertex_credentials = ( + optional_params.pop("vertex_credentials", None) + or optional_params.pop("vertex_ai_credentials", None) + or get_secret("VERTEXAI_CREDENTIALS") + ) + model_response = vertex_chat_completion.aimage_generation( # type: ignore + model=model, + prompt=prompt, + timeout=timeout, + logging_obj=litellm_logging_obj, + optional_params=optional_params, + model_response=model_response, + vertex_project=vertex_ai_project, + vertex_location=vertex_ai_location, + ) + return model_response except Exception as e: ## Map to OpenAI Exception diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index 37e560b240..3de3ba7639 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -175,7 +175,7 @@ async def test_aimage_generation_bedrock_with_optional_params(): async def test_aimage_generation_vertex_ai(): try: response = await litellm.aimage_generation( - prompt="A cute baby sea otter", + prompt="An olympic size swimming pool", model="vertex_ai/imagegeneration@006", ) print(f"response: {response}") From a4f906b464e134fbb49fab7b1efced1268c22ec6 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 10:09:41 -0700 Subject: [PATCH 03/17] feat - add litellm.ImageResponse --- litellm/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/__init__.py b/litellm/__init__.py index ac2b420d71..83e30d7754 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -724,6 +724,8 @@ from .utils import ( get_supported_openai_params, get_api_base, get_first_chars_messages, + ModelResponse, + ImageResponse, ) from .llms.huggingface_restapi import HuggingfaceConfig from .llms.anthropic import AnthropicConfig From 2519879e67ceac6cc926b2bdeb2d4d3d7bc9d7dc Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 10:45:37 -0700 Subject: [PATCH 04/17] add ImageObject --- litellm/__init__.py | 1 + litellm/llms/vertex_httpx.py | 57 +++++++++++++++++++++++++- litellm/main.py | 3 +- litellm/tests/test_image_generation.py | 8 +++- litellm/utils.py | 46 ++++++++++++++++++++- 5 files changed, 111 insertions(+), 4 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index 83e30d7754..92610afd9d 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -726,6 +726,7 @@ from .utils import ( get_first_chars_messages, ModelResponse, ImageResponse, + ImageObject, ) from .llms.huggingface_restapi import HuggingfaceConfig from .llms.anthropic import AnthropicConfig diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index ca850674b8..0e16c02e72 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -88,7 +88,7 @@ class VertexLLM(BaseLLM): assert isinstance(self._credentials.token, str) return self._credentials.token - async def aimage_generation( + def image_generation( self, prompt: str, vertex_project: str, @@ -101,6 +101,35 @@ class VertexLLM(BaseLLM): timeout: Optional[int] = None, logging_obj=None, model_response=None, + aimg_generation=False, + ): + if aimg_generation == True: + response = self.aimage_generation( + prompt=prompt, + vertex_project=vertex_project, + vertex_location=vertex_location, + model=model, + client=client, + optional_params=optional_params, + timeout=timeout, + logging_obj=logging_obj, + model_response=model_response, + ) + return response + + async def aimage_generation( + self, + prompt: str, + vertex_project: str, + vertex_location: str, + model_response: litellm.ImageResponse, + model: Optional[ + str + ] = "imagegeneration", # vertex ai uses imagegeneration as the default model + client: Optional[AsyncHTTPHandler] = None, + optional_params: Optional[dict] = None, + timeout: Optional[int] = None, + logging_obj=None, ): response = None if client is None: @@ -152,5 +181,31 @@ class VertexLLM(BaseLLM): if response.status_code != 200: raise Exception(f"Error: {response.status_code} {response.text}") + """ + Vertex AI Image generation response example: + { + "predictions": [ + { + "bytesBase64Encoded": "BASE64_IMG_BYTES", + "mimeType": "image/png" + }, + { + "mimeType": "image/png", + "bytesBase64Encoded": "BASE64_IMG_BYTES" + } + ] + } + """ + + _json_response = response.json() + _predictions = _json_response["predictions"] + + _response_data: List[litellm.ImageObject] = [] + for _prediction in _predictions: + _bytes_base64_encoded = _prediction["bytesBase64Encoded"] + image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded) + _response_data.append(image_object) + + model_response.data = _response_data return model_response diff --git a/litellm/main.py b/litellm/main.py index 198e191fd8..7601d98a2b 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3874,7 +3874,7 @@ def image_generation( or optional_params.pop("vertex_ai_credentials", None) or get_secret("VERTEXAI_CREDENTIALS") ) - model_response = vertex_chat_completion.aimage_generation( # type: ignore + model_response = vertex_chat_completion.image_generation( model=model, prompt=prompt, timeout=timeout, @@ -3883,6 +3883,7 @@ def image_generation( model_response=model_response, vertex_project=vertex_ai_project, vertex_location=vertex_ai_location, + aimg_generation=aimg_generation, ) return model_response diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index 3de3ba7639..886953c1aa 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -178,7 +178,13 @@ async def test_aimage_generation_vertex_ai(): prompt="An olympic size swimming pool", model="vertex_ai/imagegeneration@006", ) - print(f"response: {response}") + assert response.data is not None + assert len(response.data) > 0 + + for d in response.data: + assert isinstance(d, litellm.ImageObject) + print("data in response.data", d) + assert d.b64_json is not None except litellm.RateLimitError as e: pass except litellm.ContentPolicyViolationError: diff --git a/litellm/utils.py b/litellm/utils.py index 6d0231e8f2..b4a2bd618f 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -965,10 +965,54 @@ class TextCompletionResponse(OpenAIObject): setattr(self, key, value) +class ImageObject(OpenAIObject): + """ + Represents the url or the content of an image generated by the OpenAI API. + + Attributes: + b64_json: The base64-encoded JSON of the generated image, if response_format is b64_json. + url: The URL of the generated image, if response_format is url (default). + revised_prompt: The prompt that was used to generate the image, if there was any revision to the prompt. + + https://platform.openai.com/docs/api-reference/images/object + """ + + b64_json: Optional[str] = None + url: Optional[str] = None + revised_prompt: Optional[str] = None + + def __init__(self, b64_json=None, url=None, revised_prompt=None): + + super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt) + + def __contains__(self, key): + # Define custom behavior for the 'in' operator + return hasattr(self, key) + + def get(self, key, default=None): + # Custom .get() method to access attributes with a default value if the attribute doesn't exist + return getattr(self, key, default) + + def __getitem__(self, key): + # Allow dictionary-style access to attributes + return getattr(self, key) + + def __setitem__(self, key, value): + # Allow dictionary-style assignment of attributes + setattr(self, key, value) + + def json(self, **kwargs): + try: + return self.model_dump() # noqa + except: + # if using pydantic v1 + return self.dict() + + class ImageResponse(OpenAIObject): created: Optional[int] = None - data: Optional[list] = None + data: Optional[list[ImageObject]] = None usage: Optional[dict] = None From 2da89a0c8e9b6ceb1ca23eb0b14791c16bab75ec Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 10:51:25 -0700 Subject: [PATCH 05/17] fix vertex test --- litellm/tests/test_image_generation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index 886953c1aa..6acb28e5b4 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -173,6 +173,9 @@ async def test_aimage_generation_bedrock_with_optional_params(): @pytest.mark.asyncio async def test_aimage_generation_vertex_ai(): + from test_amazing_vertex_completion import load_vertex_ai_credentials + + load_vertex_ai_credentials() try: response = await litellm.aimage_generation( prompt="An olympic size swimming pool", From 655478e8dcf2982349bc072b95386a83568f869c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 10:55:10 -0700 Subject: [PATCH 06/17] fix python3.8 error --- litellm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/utils.py b/litellm/utils.py index b4a2bd618f..3dac33e564 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1012,7 +1012,7 @@ class ImageObject(OpenAIObject): class ImageResponse(OpenAIObject): created: Optional[int] = None - data: Optional[list[ImageObject]] = None + data: Optional[List[ImageObject]] = None usage: Optional[dict] = None From d50d552e5a238a988b0c7369a4a576318016f518 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 11:03:28 -0700 Subject: [PATCH 07/17] fix python 3.8 import --- litellm/llms/vertex_httpx.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index 0e16c02e72..f0b2cfcb3f 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -3,7 +3,7 @@ import json from enum import Enum import requests # type: ignore import time -from typing import Callable, Optional, Union, List +from typing import Callable, Optional, Union, List, Any from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason import litellm, uuid import httpx, inspect # type: ignore @@ -25,8 +25,6 @@ class VertexAIError(Exception): class VertexLLM(BaseLLM): - from google.auth.credentials import Credentials # type: ignore[import-untyped] - def __init__(self) -> None: from google.auth.credentials import Credentials # type: ignore[import-untyped] @@ -36,7 +34,7 @@ class VertexLLM(BaseLLM): self._credentials: Optional[Credentials] = None self.project_id: Optional[str] = None - def load_auth(self) -> tuple[Credentials, str]: + def load_auth(self) -> tuple[Any, str]: from google.auth.transport.requests import Request # type: ignore[import-untyped] from google.auth.credentials import Credentials # type: ignore[import-untyped] import google.auth as google_auth @@ -57,7 +55,7 @@ class VertexLLM(BaseLLM): return credentials, project_id - def refresh_auth(self, credentials: Credentials) -> None: + def refresh_auth(self, credentials: Any) -> None: from google.auth.transport.requests import Request # type: ignore[import-untyped] credentials.refresh(Request()) From 91f8443381d152b6b29c4955e00a4e51d505e79e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 11:11:14 -0700 Subject: [PATCH 08/17] fix add debug to vertex httpx image --- litellm/llms/vertex_httpx.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index f0b2cfcb3f..4fb554fe49 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -168,6 +168,16 @@ class VertexLLM(BaseLLM): "parameters": {"sampleCount": 1}, } + request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\"" + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + response = await client.post( url=url, headers={ From 6ddc9873e5fbf36ed0ea9360cf39650c72ee1b04 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 11:14:16 -0700 Subject: [PATCH 09/17] test vertex image gen test --- litellm/tests/test_image_generation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index 6acb28e5b4..4a5a8dac91 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -175,11 +175,14 @@ async def test_aimage_generation_bedrock_with_optional_params(): async def test_aimage_generation_vertex_ai(): from test_amazing_vertex_completion import load_vertex_ai_credentials + litellm.set_verbose = True + load_vertex_ai_credentials() try: response = await litellm.aimage_generation( prompt="An olympic size swimming pool", model="vertex_ai/imagegeneration@006", + vertex_ai_project="adroit-crow-413218", ) assert response.data is not None assert len(response.data) > 0 From 571d4cf569f6f3320c3daca6e9c5f8e5b80f5181 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 12:10:28 -0700 Subject: [PATCH 10/17] test - fix test_aimage_generation_vertex_ai --- litellm/tests/test_image_generation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index 4a5a8dac91..9fe32544bd 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -183,6 +183,7 @@ async def test_aimage_generation_vertex_ai(): prompt="An olympic size swimming pool", model="vertex_ai/imagegeneration@006", vertex_ai_project="adroit-crow-413218", + vertex_ai_location="us-central1", ) assert response.data is not None assert len(response.data) > 0 From aa0ed8238b8f76106ebca76f7d4ef7f2e27e1aa0 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 12:18:31 -0700 Subject: [PATCH 11/17] docs - image generation vertex --- docs/my-website/docs/image_generation.md | 16 ++++++++++++++++ docs/my-website/docs/providers/vertex.md | 12 ++++++++++++ 2 files changed, 28 insertions(+) diff --git a/docs/my-website/docs/image_generation.md b/docs/my-website/docs/image_generation.md index 002d95c030..7bb4d2c991 100644 --- a/docs/my-website/docs/image_generation.md +++ b/docs/my-website/docs/image_generation.md @@ -150,4 +150,20 @@ response = image_generation( model="bedrock/stability.stable-diffusion-xl-v0", ) print(f"response: {response}") +``` + +## VertexAI - Image Generation Models + +### Usage + +Use this for image generation models on VertexAI + +```python +response = litellm.image_generation( + prompt="An olympic size swimming pool", + model="vertex_ai/imagegeneration@006", + vertex_ai_project="adroit-crow-413218", + vertex_ai_location="us-central1", +) +print(f"response: {response}") ``` \ No newline at end of file diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index b67eb350b4..dc0ef48b48 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -508,6 +508,18 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02 | text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` | | text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` | +## Image Generation Models + +Usage + +```python +response = await litellm.aimage_generation( + prompt="An olympic size swimming pool", + model="vertex_ai/imagegeneration@006", + vertex_ai_project="adroit-crow-413218", + vertex_ai_location="us-central1", +) +``` ## Extra From dabaf5f2977719fd49c657a126e0b87e05594785 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 12:21:02 -0700 Subject: [PATCH 12/17] fix python 3.8 Tuple --- litellm/llms/vertex_httpx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index 4fb554fe49..9fc9080b08 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -3,7 +3,7 @@ import json from enum import Enum import requests # type: ignore import time -from typing import Callable, Optional, Union, List, Any +from typing import Callable, Optional, Union, List, Any, Tuple from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason import litellm, uuid import httpx, inspect # type: ignore @@ -34,7 +34,7 @@ class VertexLLM(BaseLLM): self._credentials: Optional[Credentials] = None self.project_id: Optional[str] = None - def load_auth(self) -> tuple[Any, str]: + def load_auth(self) -> Tuple[Any, str]: from google.auth.transport.requests import Request # type: ignore[import-untyped] from google.auth.credentials import Credentials # type: ignore[import-untyped] import google.auth as google_auth From 1fe3900800a4ffc0411b73b3ad66f23bc1fa67fe Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 12:23:27 -0700 Subject: [PATCH 13/17] fix python 3.8 --- litellm/llms/vertex_httpx.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index 9fc9080b08..61920d4e6e 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -26,12 +26,10 @@ class VertexAIError(Exception): class VertexLLM(BaseLLM): def __init__(self) -> None: - from google.auth.credentials import Credentials # type: ignore[import-untyped] - super().__init__() self.access_token: Optional[str] = None self.refresh_token: Optional[str] = None - self._credentials: Optional[Credentials] = None + self._credentials: Optional[Any] = None self.project_id: Optional[str] = None def load_auth(self) -> Tuple[Any, str]: From 11c9780ff05d2a4661749bdf98092a0cc956f4e5 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 13:11:10 -0700 Subject: [PATCH 14/17] fix self.async_handler --- litellm/llms/vertex_httpx.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index 61920d4e6e..e7b31b1554 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -31,6 +31,7 @@ class VertexLLM(BaseLLM): self.refresh_token: Optional[str] = None self._credentials: Optional[Any] = None self.project_id: Optional[str] = None + self.async_handler: Optional[AsyncHTTPHandler] = None def load_auth(self) -> Tuple[Any, str]: from google.auth.transport.requests import Request # type: ignore[import-untyped] @@ -134,9 +135,9 @@ class VertexLLM(BaseLLM): if isinstance(timeout, float) or isinstance(timeout, int): _httpx_timeout = httpx.Timeout(timeout) _params["timeout"] = _httpx_timeout - client = AsyncHTTPHandler(**_params) # type: ignore + self.async_handler = AsyncHTTPHandler(**_params) # type: ignore else: - client = client # type: ignore + self.async_handler = client # type: ignore # make POST request to # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict @@ -176,7 +177,7 @@ class VertexLLM(BaseLLM): }, ) - response = await client.post( + response = await self.async_handler.post( url=url, headers={ "Content-Type": "application/json; charset=utf-8", From 2c25bfa8dfa34298f8dffc482012e305447a9a8e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 13:13:19 -0700 Subject: [PATCH 15/17] fix vertex ai import --- litellm/llms/vertex_httpx.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index e7b31b1554..59ded6be0c 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -157,9 +157,6 @@ class VertexLLM(BaseLLM): } \ "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict" """ - - import vertexai - auth_header = self._ensure_access_token() request_data = { From 518db139820010a209394ddfb3ab0e1e6370f34a Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 13:28:20 -0700 Subject: [PATCH 16/17] add parameter mapping with vertex ai --- docs/my-website/docs/providers/vertex.md | 13 +++++++++++++ litellm/llms/vertex_httpx.py | 10 ++++++++-- litellm/tests/test_image_generation.py | 1 + litellm/utils.py | 8 ++++++++ 4 files changed, 30 insertions(+), 2 deletions(-) diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index dc0ef48b48..32c3ea1881 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -521,6 +521,19 @@ response = await litellm.aimage_generation( ) ``` +**Generating multiple images** + +Use the `n` parameter to pass how many images you want generated +```python +response = await litellm.aimage_generation( + prompt="An olympic size swimming pool", + model="vertex_ai/imagegeneration@006", + vertex_ai_project="adroit-crow-413218", + vertex_ai_location="us-central1", + n=1, +) +``` + ## Extra ### Using `GOOGLE_APPLICATION_CREDENTIALS` diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index 59ded6be0c..35a6b1d473 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -153,15 +153,21 @@ class VertexLLM(BaseLLM): { "prompt": "a cat" } - ] + ], + "parameters": { + "sampleCount": 1 + } } \ "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict" """ auth_header = self._ensure_access_token() + optional_params = optional_params or { + "sampleCount": 1 + } # default optional params request_data = { "instances": [{"prompt": prompt}], - "parameters": {"sampleCount": 1}, + "parameters": optional_params, } request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\"" diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index 9fe32544bd..35f66ad479 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -184,6 +184,7 @@ async def test_aimage_generation_vertex_ai(): model="vertex_ai/imagegeneration@006", vertex_ai_project="adroit-crow-413218", vertex_ai_location="us-central1", + n=1, ) assert response.data is not None assert len(response.data) > 0 diff --git a/litellm/utils.py b/litellm/utils.py index 3dac33e564..19f7c9910e 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4946,6 +4946,14 @@ def get_optional_params_image_gen( width, height = size.split("x") optional_params["width"] = int(width) optional_params["height"] = int(height) + elif custom_llm_provider == "vertex_ai": + supported_params = ["n"] + """ + All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218 + """ + _check_valid_arg(supported_params=supported_params) + if n is not None: + optional_params["sampleCount"] = int(n) for k in passed_params.keys(): if k not in default_params.keys(): From f3eb8325932467e37db4da9288ebdf11d5f44f65 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 20 May 2024 13:43:54 -0700 Subject: [PATCH 17/17] fix vertex httpx client --- litellm/llms/vertex_httpx.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index 35a6b1d473..b8c698c901 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -135,6 +135,9 @@ class VertexLLM(BaseLLM): if isinstance(timeout, float) or isinstance(timeout, int): _httpx_timeout = httpx.Timeout(timeout) _params["timeout"] = _httpx_timeout + else: + _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) + self.async_handler = AsyncHTTPHandler(**_params) # type: ignore else: self.async_handler = client # type: ignore