diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
index 37434fbd47..bfd89a99fc 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
@@ -13,7 +13,6 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import httpx  # type: ignore
 import requests  # type: ignore
-from openai.types.image import Image
 
 import litellm
 import litellm.litellm_core_utils
@@ -1488,248 +1487,6 @@ class VertexLLM(BaseLLM):
             encoding=encoding,
         )
 
-    def image_generation(
-        self,
-        prompt: str,
-        vertex_project: Optional[str],
-        vertex_location: Optional[str],
-        vertex_credentials: Optional[str],
-        model_response: litellm.ImageResponse,
-        model: Optional[
-            str
-        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
-        client: Optional[Any] = None,
-        optional_params: Optional[dict] = None,
-        timeout: Optional[int] = None,
-        logging_obj=None,
-        aimg_generation=False,
-    ):
-        if aimg_generation is True:
-            return self.aimage_generation(
-                prompt=prompt,
-                vertex_project=vertex_project,
-                vertex_location=vertex_location,
-                vertex_credentials=vertex_credentials,
-                model=model,
-                client=client,
-                optional_params=optional_params,
-                timeout=timeout,
-                logging_obj=logging_obj,
-                model_response=model_response,
-            )
-
-        if client is None:
-            _params = {}
-            if timeout is not None:
-                if isinstance(timeout, float) or isinstance(timeout, int):
-                    _httpx_timeout = httpx.Timeout(timeout)
-                    _params["timeout"] = _httpx_timeout
-            else:
-                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
-
-            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
-        else:
-            sync_handler = client  # type: ignore
-
-        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
-
-        auth_header, _ = self._ensure_access_token(
-            credentials=vertex_credentials, project_id=vertex_project
-        )
-        optional_params = optional_params or {
-            "sampleCount": 1
-        }  # default optional params
-
-        request_data = {
-            "instances": [{"prompt": prompt}],
-            "parameters": optional_params,
-        }
-
-        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
-        logging_obj.pre_call(
-            input=prompt,
-            api_key=None,
-            additional_args={
-                "complete_input_dict": optional_params,
-                "request_str": request_str,
-            },
-        )
-
-        logging_obj.pre_call(
-            input=prompt,
-            api_key=None,
-            additional_args={
-                "complete_input_dict": optional_params,
-                "request_str": request_str,
-            },
-        )
-
-        response = sync_handler.post(
-            url=url,
-            headers={
-                "Content-Type": "application/json; charset=utf-8",
-                "Authorization": f"Bearer {auth_header}",
-            },
-            data=json.dumps(request_data),
-        )
-
-        if response.status_code != 200:
-            raise Exception(f"Error: {response.status_code} {response.text}")
-        """
-        Vertex AI Image generation response example:
-        {
-            "predictions": [
-                {
-                    "bytesBase64Encoded": "BASE64_IMG_BYTES",
-                    "mimeType": "image/png"
-                },
-                {
-                    "mimeType": "image/png",
-                    "bytesBase64Encoded": "BASE64_IMG_BYTES"
-                }
-            ]
-        }
-        """
-
-        _json_response = response.json()
-        if "predictions" not in _json_response:
-            raise litellm.InternalServerError(
-                message=f"image generation response does not contain 'predictions', got {_json_response}",
-                llm_provider="vertex_ai",
-                model=model,
-            )
-        _predictions = _json_response["predictions"]
-
-        _response_data: List[Image] = []
-        for _prediction in _predictions:
-            _bytes_base64_encoded = _prediction["bytesBase64Encoded"]
-            image_object = Image(b64_json=_bytes_base64_encoded)
-            _response_data.append(image_object)
-
-        model_response.data = _response_data
-
-        return model_response
-
-    async def aimage_generation(
-        self,
-        prompt: str,
-        vertex_project: Optional[str],
-        vertex_location: Optional[str],
-        vertex_credentials: Optional[str],
-        model_response: litellm.ImageResponse,
-        model: Optional[
-            str
-        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
-        client: Optional[AsyncHTTPHandler] = None,
-        optional_params: Optional[dict] = None,
-        timeout: Optional[int] = None,
-        logging_obj=None,
-    ):
-        response = None
-        if client is None:
-            _params = {}
-            if timeout is not None:
-                if isinstance(timeout, float) or isinstance(timeout, int):
-                    _httpx_timeout = httpx.Timeout(timeout)
-                    _params["timeout"] = _httpx_timeout
-            else:
-                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
-
-            self.async_handler = AsyncHTTPHandler(**_params)  # type: ignore
-        else:
-            self.async_handler = client  # type: ignore
-
-        # make POST request to
-        # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
-        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
-
-        """
-        Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
-        curl -X POST \
-        -H "Authorization: Bearer $(gcloud auth print-access-token)" \
-        -H "Content-Type: application/json; charset=utf-8" \
-        -d {
-            "instances": [
-                {
-                    "prompt": "a cat"
-                }
-            ],
-            "parameters": {
-                "sampleCount": 1
-            }
-        } \
-        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
-        """
-        auth_header, _ = self._ensure_access_token(
-            credentials=vertex_credentials, project_id=vertex_project
-        )
-        optional_params = optional_params or {
-            "sampleCount": 1
-        }  # default optional params
-
-        request_data = {
-            "instances": [{"prompt": prompt}],
-            "parameters": optional_params,
-        }
-
-        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
-        logging_obj.pre_call(
-            input=prompt,
-            api_key=None,
-            additional_args={
-                "complete_input_dict": optional_params,
-                "request_str": request_str,
-            },
-        )
-
-        response = await self.async_handler.post(
-            url=url,
-            headers={
-                "Content-Type": "application/json; charset=utf-8",
-                "Authorization": f"Bearer {auth_header}",
-            },
-            data=json.dumps(request_data),
-        )
-
-        if response.status_code != 200:
-            raise Exception(f"Error: {response.status_code} {response.text}")
-        """
-        Vertex AI Image generation response example:
-        {
-            "predictions": [
-                {
-                    "bytesBase64Encoded": "BASE64_IMG_BYTES",
-                    "mimeType": "image/png"
-                },
-                {
-                    "mimeType": "image/png",
-                    "bytesBase64Encoded": "BASE64_IMG_BYTES"
-                }
-            ]
-        }
-        """
-
-        _json_response = response.json()
-
-        if "predictions" not in _json_response:
-            raise litellm.InternalServerError(
-                message=f"image generation response does not contain 'predictions', got {_json_response}",
-                llm_provider="vertex_ai",
-                model=model,
-            )
-
-        _predictions = _json_response["predictions"]
-
-        _response_data: List[Image] = []
-        for _prediction in _predictions:
-            _bytes_base64_encoded = _prediction["bytesBase64Encoded"]
-            image_object = Image(b64_json=_bytes_base64_encoded)
-            _response_data.append(image_object)
-
-        model_response.data = _response_data
-
-        return model_response
-
 
 class ModelResponseIterator:
     def __init__(self, streaming_response, sync_stream: bool):
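For context, here is a minimal standalone sketch (not part of the diff) of the Vertex AI `:predict` request/response shape that the removed methods above and the new handler below both implement; the helper names are illustrative only:

```python
import base64
import json
from typing import Any, Dict, List


def build_imagen_request(prompt: str, sample_count: int = 1) -> str:
    # Request body: one instance per prompt; generation knobs such as
    # "sampleCount" travel under "parameters".
    return json.dumps(
        {
            "instances": [{"prompt": prompt}],
            "parameters": {"sampleCount": sample_count},
        }
    )


def decode_predictions(json_response: Dict[str, Any]) -> List[bytes]:
    # Response body: {"predictions": [{"bytesBase64Encoded": "...", "mimeType": "image/png"}, ...]}
    return [
        base64.b64decode(prediction["bytesBase64Encoded"])
        for prediction in json_response.get("predictions", [])
    ]
```

The refactor that follows moves this logic, unchanged, into a dedicated `VertexImageGeneration` handler and deduplicates the response parsing into `process_image_generation_response`.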
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py
new file mode 100644
index 0000000000..dac4f08b69
--- /dev/null
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py
@@ -0,0 +1,225 @@
+import json
+from typing import Any, Dict, List, Optional
+
+import httpx
+from openai.types.image import Image
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
+    VertexLLM,
+)
+
+
+class VertexImageGeneration(VertexLLM):
+    def process_image_generation_response(
+        self,
+        json_response: Dict[str, Any],
+        model_response: litellm.ImageResponse,
+        model: str,
+    ) -> litellm.ImageResponse:
+        if "predictions" not in json_response:
+            raise litellm.InternalServerError(
+                message=f"image generation response does not contain 'predictions', got {json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+
+        predictions = json_response["predictions"]
+        response_data: List[Image] = []
+
+        for prediction in predictions:
+            bytes_base64_encoded = prediction["bytesBase64Encoded"]
+            image_object = Image(b64_json=bytes_base64_encoded)
+            response_data.append(image_object)
+
+        model_response.data = response_data
+        return model_response
+
+    def image_generation(
+        self,
+        prompt: str,
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        model_response: litellm.ImageResponse,
+        model: Optional[
+            str
+        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
+        client: Optional[Any] = None,
+        optional_params: Optional[dict] = None,
+        timeout: Optional[int] = None,
+        logging_obj=None,
+        aimg_generation=False,
+    ):
+        if aimg_generation is True:
+            return self.aimage_generation(
+                prompt=prompt,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+                vertex_credentials=vertex_credentials,
+                model=model,
+                client=client,
+                optional_params=optional_params,
+                timeout=timeout,
+                logging_obj=logging_obj,
+                model_response=model_response,
+            )
+
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+        auth_header, _ = self._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+        optional_params = optional_params or {
+            "sampleCount": 1
+        }  # default optional params
+
+        request_data = {
+            "instances": [{"prompt": prompt}],
+            "parameters": optional_params,
+        }
+
+        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+        logging_obj.pre_call(
+            input=prompt,
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        logging_obj.pre_call(
+            input=prompt,
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        response = sync_handler.post(
+            url=url,
+            headers={
+                "Content-Type": "application/json; charset=utf-8",
+                "Authorization": f"Bearer {auth_header}",
+            },
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        json_response = response.json()
+        return self.process_image_generation_response(
+            json_response, model_response, model
+        )
+
+    async def aimage_generation(
+        self,
+        prompt: str,
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        model_response: litellm.ImageResponse,
+        model: Optional[
+            str
+        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
+        client: Optional[AsyncHTTPHandler] = None,
+        optional_params: Optional[dict] = None,
+        timeout: Optional[int] = None,
+        logging_obj=None,
+    ):
+        response = None
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            self.async_handler = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            self.async_handler = client  # type: ignore
+
+        # make POST request to
+        # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+        """
+        Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
+        curl -X POST \
+        -H "Authorization: Bearer $(gcloud auth print-access-token)" \
+        -H "Content-Type: application/json; charset=utf-8" \
+        -d {
+            "instances": [
+                {
+                    "prompt": "a cat"
+                }
+            ],
+            "parameters": {
+                "sampleCount": 1
+            }
+        } \
+        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
+        """
+        auth_header, _ = self._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+        optional_params = optional_params or {
+            "sampleCount": 1
+        }  # default optional params
+
+        request_data = {
+            "instances": [{"prompt": prompt}],
+            "parameters": optional_params,
+        }
+
+        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+        logging_obj.pre_call(
+            input=prompt,
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        response = await self.async_handler.post(
+            url=url,
+            headers={
+                "Content-Type": "application/json; charset=utf-8",
+                "Authorization": f"Bearer {auth_header}",
+            },
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        json_response = response.json()
+        return self.process_image_generation_response(
+            json_response, model_response, model
+        )
+
+    def is_image_generation_response(self, json_response: Dict[str, Any]) -> bool:
+        if "predictions" in json_response:
+            if "bytesBase64Encoded" in json_response["predictions"][0]:
+                return True
+        return False
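A hedged usage sketch for the new handler (the class, method, and parameter names come from the diff; the stub logging object and GCP project id are placeholders, and passing `vertex_credentials=None` assumes litellm falls back to ambient Google credentials):

```python
import litellm
from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
    VertexImageGeneration,
)


class _NoopLogging:
    # Stand-in for litellm's logging object; only pre_call is invoked by the handler.
    def pre_call(self, input, api_key, additional_args=None):
        pass


vertex_image_generation = VertexImageGeneration()
response = vertex_image_generation.image_generation(
    prompt="a cat",
    vertex_project="my-gcp-project",  # placeholder project id
    vertex_location="us-central1",
    vertex_credentials=None,
    model_response=litellm.ImageResponse(),
    model="imagegeneration",
    optional_params={"sampleCount": 1},
    logging_obj=_NoopLogging(),
)
print(len(response.data))  # one openai.types.image.Image (b64_json) per sample
```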
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index cfeebac8b3..fbefa867b5 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -844,7 +844,7 @@ async def _PROXY_track_cost_callback(
                 kwargs["stream"] == True and "complete_streaming_response" in kwargs
             ):
                 raise Exception(
-                    f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
+                    f"Model not in litellm model cost map. Passed model = {kwargs.get('model')} - Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
                 )
     except Exception as e:
         error_msg = f"error in tracking cost callback - {traceback.format_exc()}"
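The proxy change above only enriches the error message with the offending model name; the remedy it points to is registering custom pricing for that model. A sketch based on the linked docs (model name and costs are placeholders):

```python
import litellm

# Register placeholder per-token pricing so _PROXY_track_cost_callback can
# price a model that is missing from litellm's built-in model cost map.
litellm.register_model(
    {
        "vertex_ai/my-custom-model": {  # hypothetical model name
            "input_cost_per_token": 0.0000004,  # placeholder costs
            "output_cost_per_token": 0.0000008,
            "litellm_provider": "vertex_ai",
            "mode": "chat",
        }
    }
)
```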