From eb9e4f178775db30bc840c5f9885e144e28e3002 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 2 Sep 2024 17:05:53 -0700
Subject: [PATCH 01/12] track /embedding in spendLogs

---
 .../vertex_embeddings/embedding_handler.py    | 30 +++++++++++++++++++
 .../pass_through_endpoints/success_handler.py | 24 +++++++++++++++
 2 files changed, 54 insertions(+)

diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py
index 4cd5513c4..5638c58cd 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py
@@ -281,3 +281,33 @@ async def async_embedding(
     )
     setattr(model_response, "usage", usage)
     return model_response
+
+
+async def transform_vertex_response_to_openai(
+    response: dict, model: str, model_response: litellm.EmbeddingResponse
+) -> litellm.EmbeddingResponse:
+
+    _predictions = response["predictions"]
+
+    embedding_response = []
+    input_tokens: int = 0
+    for idx, element in enumerate(_predictions):
+
+        embedding = element["embeddings"]
+        embedding_response.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding["values"],
+            }
+        )
+        input_tokens += embedding["statistics"]["token_count"]
+
+    model_response.object = "list"
+    model_response.data = embedding_response
+    model_response.model = model
+    usage = Usage(
+        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+    )
+    setattr(model_response, "usage", usage)
+    return model_response
diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py
index 618f68659..5ed6a1948 100644
--- a/litellm/proxy/pass_through_endpoints/success_handler.py
+++ b/litellm/proxy/pass_through_endpoints/success_handler.py
@@ -97,6 +97,30 @@ class PassThroughEndpointLogging:
                 logging_obj.model = litellm_model_response.model
                 logging_obj.model_call_details["model"] = logging_obj.model
 
+                await logging_obj.async_success_handler(
+                    result=litellm_model_response,
+                    start_time=start_time,
+                    end_time=end_time,
+                    cache_hit=cache_hit,
+                )
+            elif "predict" in url_route:
+                from litellm.llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
+                    transform_vertex_response_to_openai,
+                )
+
+                model = self.extract_model_from_url(url_route)
+                _json_response = httpx_response.json()
+
+                litellm_model_response = await transform_vertex_response_to_openai(
+                    response=_json_response,
+                    model=model,
+                    model_response=litellm.EmbeddingResponse(),
+                )
+
+                litellm_model_response.model = model
+                logging_obj.model = litellm_model_response.model
+                logging_obj.model_call_details["model"] = logging_obj.model
+
                 await logging_obj.async_success_handler(
                     result=litellm_model_response,
                     start_time=start_time,

From 296dba89231fd30d8257c9751ad7a6fa15ed0d0e Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 2 Sep 2024 17:08:30 -0700
Subject: [PATCH 02/12] fix linting

---
 litellm/proxy/pass_through_endpoints/streaming_handler.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py
index ab1d5d813..4420bd1d7 100644
--- a/litellm/proxy/pass_through_endpoints/streaming_handler.py
+++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py
@@ -95,9 +95,9 @@ async def chunk_processor(
             except 
json.JSONDecodeError:
                 pass
 
-        complete_streaming_response: litellm.ModelResponse = (
-            litellm.stream_chunk_builder(chunks=all_chunks)
-        )
+        complete_streaming_response: Optional[
+            Union[litellm.ModelResponse, litellm.TextCompletionResponse]
+        ] = litellm.stream_chunk_builder(chunks=all_chunks)
         end_time = datetime.now()
 
         if passthrough_success_handler_obj.is_vertex_route(url_route):

From 6393d2391ee2bc587ac41155b981bcb5d5375d97 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 2 Sep 2024 17:34:43 -0700
Subject: [PATCH 03/12] refactor vertex to use separate image gen folder

---
 litellm/main.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/litellm/main.py b/litellm/main.py
index 70cd40f31..a62703a50 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -126,6 +126,9 @@ from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gem
 from .llms.vertex_ai_and_google_ai_studio.gemini_embeddings.batch_embed_content_handler import (
     GoogleBatchEmbeddings,
 )
+from .llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
+    VertexImageGeneration,
+)
 from .llms.vertex_ai_and_google_ai_studio.multimodal_embeddings.embedding_handler import (
     VertexMultimodalEmbedding,
 )
@@ -180,6 +183,7 @@ bedrock_converse_chat_completion = BedrockConverseLLM()
 bedrock_embedding = BedrockEmbedding()
 vertex_chat_completion = VertexLLM()
 vertex_multimodal_embedding = VertexMultimodalEmbedding()
+vertex_image_generation = VertexImageGeneration()
 google_batch_embeddings = GoogleBatchEmbeddings()
 vertex_partner_models_chat_completion = VertexAIPartnerModels()
 vertex_text_to_speech = VertexTextToSpeechAPI()
@@ -4534,7 +4538,7 @@ def image_generation(
                 or optional_params.pop("vertex_ai_credentials", None)
                 or get_secret("VERTEXAI_CREDENTIALS")
             )
-            model_response = vertex_chat_completion.image_generation(
+            model_response = vertex_image_generation.image_generation(
                 model=model,
                 prompt=prompt,
                 timeout=timeout,

From aa13977136ed323e646975cce5345bff276ad31c Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 2 Sep 2024 17:35:51 -0700
Subject: [PATCH 04/12] refactor vtx image gen

---
 .../vertex_and_google_ai_studio_gemini.py     | 243 ------------------
 .../image_generation_handler.py               | 225 ++++++++++++++++
 litellm/proxy/proxy_server.py                 |   2 +-
 3 files changed, 226 insertions(+), 244 deletions(-)
 create mode 100644 litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py

diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
index 37434fbd4..bfd89a99f 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
@@ -13,7 +13,6 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import httpx  # type: ignore
 import requests  # type: ignore
-from openai.types.image import Image
 
 import litellm
 import litellm.litellm_core_utils
@@ -1488,248 +1487,6 @@ class VertexLLM(BaseLLM):
             encoding=encoding,
         )
 
-    def image_generation(
-        self,
-        prompt: str,
-        vertex_project: Optional[str],
-        vertex_location: Optional[str],
-        vertex_credentials: Optional[str],
-        model_response: litellm.ImageResponse,
-        model: Optional[
-            str
-        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
-        client: Optional[Any] = None,
-        optional_params: 
Optional[dict] = None, - timeout: Optional[int] = None, - logging_obj=None, - aimg_generation=False, - ): - if aimg_generation is True: - return self.aimage_generation( - prompt=prompt, - vertex_project=vertex_project, - vertex_location=vertex_location, - vertex_credentials=vertex_credentials, - model=model, - client=client, - optional_params=optional_params, - timeout=timeout, - logging_obj=logging_obj, - model_response=model_response, - ) - - if client is None: - _params = {} - if timeout is not None: - if isinstance(timeout, float) or isinstance(timeout, int): - _httpx_timeout = httpx.Timeout(timeout) - _params["timeout"] = _httpx_timeout - else: - _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) - - sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore - else: - sync_handler = client # type: ignore - - url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict" - - auth_header, _ = self._ensure_access_token( - credentials=vertex_credentials, project_id=vertex_project - ) - optional_params = optional_params or { - "sampleCount": 1 - } # default optional params - - request_data = { - "instances": [{"prompt": prompt}], - "parameters": optional_params, - } - - request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\"" - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - response = sync_handler.post( - url=url, - headers={ - "Content-Type": "application/json; charset=utf-8", - "Authorization": f"Bearer {auth_header}", - }, - data=json.dumps(request_data), - ) - - if response.status_code != 200: - raise Exception(f"Error: {response.status_code} {response.text}") - """ - Vertex AI Image generation response example: - { - "predictions": [ - { - "bytesBase64Encoded": "BASE64_IMG_BYTES", - "mimeType": "image/png" - }, - { - "mimeType": "image/png", - "bytesBase64Encoded": "BASE64_IMG_BYTES" - } - ] - } - """ - - _json_response = response.json() - if "predictions" not in _json_response: - raise litellm.InternalServerError( - message=f"image generation response does not contain 'predictions', got {_json_response}", - llm_provider="vertex_ai", - model=model, - ) - _predictions = _json_response["predictions"] - - _response_data: List[Image] = [] - for _prediction in _predictions: - _bytes_base64_encoded = _prediction["bytesBase64Encoded"] - image_object = Image(b64_json=_bytes_base64_encoded) - _response_data.append(image_object) - - model_response.data = _response_data - - return model_response - - async def aimage_generation( - self, - prompt: str, - vertex_project: Optional[str], - vertex_location: Optional[str], - vertex_credentials: Optional[str], - model_response: litellm.ImageResponse, - model: Optional[ - str - ] = "imagegeneration", # vertex ai uses imagegeneration as the default model - client: Optional[AsyncHTTPHandler] = None, - optional_params: Optional[dict] = None, - timeout: Optional[int] = None, - logging_obj=None, - ): - response = None - if client is None: - _params = {} - if timeout is not None: - if isinstance(timeout, float) or isinstance(timeout, int): - 
_httpx_timeout = httpx.Timeout(timeout) - _params["timeout"] = _httpx_timeout - else: - _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) - - self.async_handler = AsyncHTTPHandler(**_params) # type: ignore - else: - self.async_handler = client # type: ignore - - # make POST request to - # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict - url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict" - - """ - Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218 - curl -X POST \ - -H "Authorization: Bearer $(gcloud auth print-access-token)" \ - -H "Content-Type: application/json; charset=utf-8" \ - -d { - "instances": [ - { - "prompt": "a cat" - } - ], - "parameters": { - "sampleCount": 1 - } - } \ - "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict" - """ - auth_header, _ = self._ensure_access_token( - credentials=vertex_credentials, project_id=vertex_project - ) - optional_params = optional_params or { - "sampleCount": 1 - } # default optional params - - request_data = { - "instances": [{"prompt": prompt}], - "parameters": optional_params, - } - - request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\"" - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - response = await self.async_handler.post( - url=url, - headers={ - "Content-Type": "application/json; charset=utf-8", - "Authorization": f"Bearer {auth_header}", - }, - data=json.dumps(request_data), - ) - - if response.status_code != 200: - raise Exception(f"Error: {response.status_code} {response.text}") - """ - Vertex AI Image generation response example: - { - "predictions": [ - { - "bytesBase64Encoded": "BASE64_IMG_BYTES", - "mimeType": "image/png" - }, - { - "mimeType": "image/png", - "bytesBase64Encoded": "BASE64_IMG_BYTES" - } - ] - } - """ - - _json_response = response.json() - - if "predictions" not in _json_response: - raise litellm.InternalServerError( - message=f"image generation response does not contain 'predictions', got {_json_response}", - llm_provider="vertex_ai", - model=model, - ) - - _predictions = _json_response["predictions"] - - _response_data: List[Image] = [] - for _prediction in _predictions: - _bytes_base64_encoded = _prediction["bytesBase64Encoded"] - image_object = Image(b64_json=_bytes_base64_encoded) - _response_data.append(image_object) - - model_response.data = _response_data - - return model_response - class ModelResponseIterator: def __init__(self, streaming_response, sync_stream: bool): diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py new file mode 100644 index 000000000..dac4f08b6 --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py @@ -0,0 +1,225 @@ +import json +from typing import Any, Dict, List, Optional + +import httpx +from openai.types.image import Image + +import litellm +from 
litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( + VertexLLM, +) + + +class VertexImageGeneration(VertexLLM): + def process_image_generation_response( + self, + json_response: Dict[str, Any], + model_response: litellm.ImageResponse, + model: str, + ) -> litellm.ImageResponse: + if "predictions" not in json_response: + raise litellm.InternalServerError( + message=f"image generation response does not contain 'predictions', got {json_response}", + llm_provider="vertex_ai", + model=model, + ) + + predictions = json_response["predictions"] + response_data: List[Image] = [] + + for prediction in predictions: + bytes_base64_encoded = prediction["bytesBase64Encoded"] + image_object = Image(b64_json=bytes_base64_encoded) + response_data.append(image_object) + + model_response.data = response_data + return model_response + + def image_generation( + self, + prompt: str, + vertex_project: Optional[str], + vertex_location: Optional[str], + vertex_credentials: Optional[str], + model_response: litellm.ImageResponse, + model: Optional[ + str + ] = "imagegeneration", # vertex ai uses imagegeneration as the default model + client: Optional[Any] = None, + optional_params: Optional[dict] = None, + timeout: Optional[int] = None, + logging_obj=None, + aimg_generation=False, + ): + if aimg_generation is True: + return self.aimage_generation( + prompt=prompt, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + model=model, + client=client, + optional_params=optional_params, + timeout=timeout, + logging_obj=logging_obj, + model_response=model_response, + ) + + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + _httpx_timeout = httpx.Timeout(timeout) + _params["timeout"] = _httpx_timeout + else: + _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) + + sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore + else: + sync_handler = client # type: ignore + + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict" + + auth_header, _ = self._ensure_access_token( + credentials=vertex_credentials, project_id=vertex_project + ) + optional_params = optional_params or { + "sampleCount": 1 + } # default optional params + + request_data = { + "instances": [{"prompt": prompt}], + "parameters": optional_params, + } + + request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\"" + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + + response = sync_handler.post( + url=url, + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {auth_header}", + }, + data=json.dumps(request_data), + ) + + if response.status_code != 200: + raise Exception(f"Error: {response.status_code} {response.text}") + + json_response = response.json() + return self.process_image_generation_response( + json_response, model_response, model + ) + + async 
def aimage_generation( + self, + prompt: str, + vertex_project: Optional[str], + vertex_location: Optional[str], + vertex_credentials: Optional[str], + model_response: litellm.ImageResponse, + model: Optional[ + str + ] = "imagegeneration", # vertex ai uses imagegeneration as the default model + client: Optional[AsyncHTTPHandler] = None, + optional_params: Optional[dict] = None, + timeout: Optional[int] = None, + logging_obj=None, + ): + response = None + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + _httpx_timeout = httpx.Timeout(timeout) + _params["timeout"] = _httpx_timeout + else: + _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) + + self.async_handler = AsyncHTTPHandler(**_params) # type: ignore + else: + self.async_handler = client # type: ignore + + # make POST request to + # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict" + + """ + Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218 + curl -X POST \ + -H "Authorization: Bearer $(gcloud auth print-access-token)" \ + -H "Content-Type: application/json; charset=utf-8" \ + -d { + "instances": [ + { + "prompt": "a cat" + } + ], + "parameters": { + "sampleCount": 1 + } + } \ + "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict" + """ + auth_header, _ = self._ensure_access_token( + credentials=vertex_credentials, project_id=vertex_project + ) + optional_params = optional_params or { + "sampleCount": 1 + } # default optional params + + request_data = { + "instances": [{"prompt": prompt}], + "parameters": optional_params, + } + + request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\"" + logging_obj.pre_call( + input=prompt, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + + response = await self.async_handler.post( + url=url, + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {auth_header}", + }, + data=json.dumps(request_data), + ) + + if response.status_code != 200: + raise Exception(f"Error: {response.status_code} {response.text}") + + json_response = response.json() + return self.process_image_generation_response( + json_response, model_response, model + ) + + def is_image_generation_response(self, json_response: Dict[str, Any]) -> bool: + if "predictions" in json_response: + if "bytesBase64Encoded" in json_response["predictions"][0]: + return True + return False diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index cfeebac8b..fbefa867b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -844,7 +844,7 @@ async def _PROXY_track_cost_callback( kwargs["stream"] == True and "complete_streaming_response" in kwargs ): raise Exception( - f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing" + f"Model not in litellm model cost map. 
Passed model = {kwargs.get('model')} - Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing" ) except Exception as e: error_msg = f"error in tracking cost callback - {traceback.format_exc()}" From 811aa34a36388fee37f50765ac79cbc6751fe5ff Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 2 Sep 2024 17:36:25 -0700 Subject: [PATCH 05/12] track image gen in spend logs --- .../pass_through_endpoints/success_handler.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index 5ed6a1948..39f1d14ab 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -104,18 +104,35 @@ class PassThroughEndpointLogging: cache_hit=cache_hit, ) elif "predict" in url_route: + from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import ( + VertexImageGeneration, + ) from litellm.llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import ( transform_vertex_response_to_openai, ) + vertex_image_generation_class = VertexImageGeneration() + model = self.extract_model_from_url(url_route) _json_response = httpx_response.json() - litellm_model_response = await transform_vertex_response_to_openai( - response=_json_response, - model=model, - model_response=litellm.EmbeddingResponse(), - ) + litellm_model_response = litellm.ModelResponse() + if vertex_image_generation_class.is_image_generation_response( + _json_response + ): + litellm_model_response = ( + vertex_image_generation_class.process_image_generation_response( + _json_response, + model_response=litellm.ImageResponse(), + model=model, + ) + ) + else: + litellm_model_response = await transform_vertex_response_to_openai( + response=_json_response, + model=model, + model_response=litellm.EmbeddingResponse(), + ) litellm_model_response.model = model logging_obj.model = litellm_model_response.model From 80dd2cfc7f1bb4d5e239bdc36fe88138e6ac839e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 2 Sep 2024 17:47:29 -0700 Subject: [PATCH 06/12] fix get_llm_provider for imagegeneration@006 --- litellm/__init__.py | 4 ++++ litellm/tests/test_get_llm_provider.py | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/litellm/__init__.py b/litellm/__init__.py index 3f22e41b6..4148fdc78 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -355,6 +355,7 @@ vertex_language_models: List = [] vertex_vision_models: List = [] vertex_chat_models: List = [] vertex_code_chat_models: List = [] +vertex_ai_image_models: List = [] vertex_text_models: List = [] vertex_code_text_models: List = [] vertex_embedding_models: List = [] @@ -414,6 +415,9 @@ for key, value in model_cost.items(): elif value.get("litellm_provider") == "vertex_ai-ai21_models": key = key.replace("vertex_ai/", "") vertex_ai_ai21_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-image-models": + key = key.replace("vertex_ai/", "") + vertex_ai_image_models.append(key) elif value.get("litellm_provider") == "ai21": ai21_models.append(key) elif value.get("litellm_provider") == "nlp_cloud": diff --git a/litellm/tests/test_get_llm_provider.py b/litellm/tests/test_get_llm_provider.py index 5e1c1f4fe..4eef036a7 100644 --- a/litellm/tests/test_get_llm_provider.py +++ b/litellm/tests/test_get_llm_provider.py @@ -68,3 +68,10 @@ def test_get_llm_provider_deepseek_custom_api_base(): assert api_base == 
"MY-FAKE-BASE" os.environ.pop("DEEPSEEK_API_BASE") + + +def test_get_llm_provider_vertex_ai_image_models(): + model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider( + model="imagegeneration@006", + ) + assert custom_llm_provider == "vertex_ai" From 9fcab392a4a536c7d433b8c929057ac0e9ecf719 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 2 Sep 2024 17:49:36 -0700 Subject: [PATCH 07/12] fix get llm provider for imagen --- litellm/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/utils.py b/litellm/utils.py index efd48e8ab..bb16c9f0a 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4930,6 +4930,7 @@ def get_llm_provider( or model in litellm.vertex_language_models or model in litellm.vertex_embedding_models or model in litellm.vertex_vision_models + or model in litellm.vertex_ai_image_models ): custom_llm_provider = "vertex_ai" ## ai21 From 4a0fdc40f15f2ab60b749169c3f6e7e64a0c91ea Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 2 Sep 2024 18:10:46 -0700 Subject: [PATCH 08/12] add cost tracking for pass through imagen --- litellm/cost_calculator.py | 3 ++- litellm/proxy/pass_through_endpoints/success_handler.py | 5 +++++ litellm/tests/test_get_llm_provider.py | 2 +- litellm/types/utils.py | 4 ++++ 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index a0645c19a..a66a80002 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -24,7 +24,7 @@ from litellm.llms.anthropic.cost_calculation import ( ) from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS -from litellm.types.utils import Usage +from litellm.types.utils import PassthroughCallTypes, Usage from litellm.utils import ( CallTypes, CostPerToken, @@ -625,6 +625,7 @@ def completion_cost( if ( call_type == CallTypes.image_generation.value or call_type == CallTypes.aimage_generation.value + or call_type == PassthroughCallTypes.passthrough_image_generation.value ): ### IMAGE GENERATION COST CALCULATION ### if custom_llm_provider == "vertex_ai": diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index 39f1d14ab..f29129df1 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -110,6 +110,7 @@ class PassThroughEndpointLogging: from litellm.llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import ( transform_vertex_response_to_openai, ) + from litellm.types.utils import PassthroughCallTypes vertex_image_generation_class = VertexImageGeneration() @@ -127,6 +128,10 @@ class PassThroughEndpointLogging: model=model, ) ) + + logging_obj.call_type = ( + PassthroughCallTypes.passthrough_image_generation.value + ) else: litellm_model_response = await transform_vertex_response_to_openai( response=_json_response, diff --git a/litellm/tests/test_get_llm_provider.py b/litellm/tests/test_get_llm_provider.py index 4eef036a7..8f585a072 100644 --- a/litellm/tests/test_get_llm_provider.py +++ b/litellm/tests/test_get_llm_provider.py @@ -72,6 +72,6 @@ def test_get_llm_provider_deepseek_custom_api_base(): def test_get_llm_provider_vertex_ai_image_models(): model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider( - model="imagegeneration@006", + model="imagegeneration@006", custom_llm_provider=None ) assert custom_llm_provider == "vertex_ai" diff 
--git a/litellm/types/utils.py b/litellm/types/utils.py
index 9e8c7be34..d649a30f0 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -119,6 +119,10 @@ class CallTypes(Enum):
     speech = "speech"
 
 
+class PassthroughCallTypes(Enum):
+    passthrough_image_generation = "passthrough-image-generation"
+
+
 class TopLogprob(OpenAIObject):
     token: str
     """The token."""

From ad10dcd3c30677f9cdac48835702b022dd81f5c1 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 2 Sep 2024 18:14:15 -0700
Subject: [PATCH 09/12] fix linting error

---
 litellm/proxy/pass_through_endpoints/streaming_handler.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py
index 4420bd1d7..b7faa21e4 100644
--- a/litellm/proxy/pass_through_endpoints/streaming_handler.py
+++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py
@@ -98,6 +98,8 @@ async def chunk_processor(
         complete_streaming_response: Optional[
             Union[litellm.ModelResponse, litellm.TextCompletionResponse]
         ] = litellm.stream_chunk_builder(chunks=all_chunks)
+        if complete_streaming_response is None:
+            complete_streaming_response = litellm.ModelResponse()
         end_time = datetime.now()
 
         if passthrough_success_handler_obj.is_vertex_route(url_route):

From b0735c9e9f9d8cdf65c51315842237b9079212f0 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 2 Sep 2024 18:17:12 -0700
Subject: [PATCH 10/12] add doc with supported imagen models

---
 docs/my-website/docs/providers/vertex.md | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md
index 868c8602c..ac134d009 100644
--- a/docs/my-website/docs/providers/vertex.md
+++ b/docs/my-website/docs/providers/vertex.md
@@ -1833,6 +1833,19 @@ response = await litellm.aimage_generation(
 )
 ```
 
+### Supported Image Generation Models
+
+| Model Name | Usage |
+|------------------------------|--------------------------------------------------------------|
+| `imagen-3.0-generate-001` | `litellm.image_generation('vertex_ai/imagen-3.0-generate-001', prompt)` |
+| `imagen-3.0-fast-generate-001` | `litellm.image_generation('vertex_ai/imagen-3.0-fast-generate-001', prompt)` |
+| `imagegeneration@006` | `litellm.image_generation('vertex_ai/imagegeneration@006', prompt)` |
+| `imagegeneration@005` | `litellm.image_generation('vertex_ai/imagegeneration@005', prompt)` |
+| `imagegeneration@002` | `litellm.image_generation('vertex_ai/imagegeneration@002', prompt)` |
+
+
+
 ## **Text to Speech APIs**
 
 :::info

From bfb39eb0cdec7e172343ba311e6b5d707b3a849d Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 2 Sep 2024 19:39:10 -0700
Subject: [PATCH 11/12] fix linting errors

---
 .../image_generation/image_generation_handler.py          | 2 +-
 litellm/proxy/pass_through_endpoints/success_handler.py   | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py
index dac4f08b6..440d0841a 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py
@@ -16,7 +16,7 @@ class VertexImageGeneration(VertexLLM):
         self,
         json_response: Dict[str, Any],
         model_response: litellm.ImageResponse,
-        model: str,
+        model: Optional[str] = None,
     ) 
-> litellm.ImageResponse: if "predictions" not in json_response: raise litellm.InternalServerError( diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index f29129df1..5d315ae3d 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -1,5 +1,6 @@ import re from datetime import datetime +from typing import Union import httpx @@ -117,7 +118,9 @@ class PassThroughEndpointLogging: model = self.extract_model_from_url(url_route) _json_response = httpx_response.json() - litellm_model_response = litellm.ModelResponse() + litellm_model_response: Union[ + litellm.ModelResponse, litellm.EmbeddingResponse, litellm.ImageResponse + ] = litellm.ModelResponse() if vertex_image_generation_class.is_image_generation_response( _json_response ): From 35c0c07b249748a2c092c314054d81d4afaae231 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 2 Sep 2024 19:42:36 -0700 Subject: [PATCH 12/12] fix success handler typing --- .../proxy/pass_through_endpoints/success_handler.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index 5d315ae3d..b3b2e94bf 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -118,13 +118,13 @@ class PassThroughEndpointLogging: model = self.extract_model_from_url(url_route) _json_response = httpx_response.json() - litellm_model_response: Union[ + litellm_prediction_response: Union[ litellm.ModelResponse, litellm.EmbeddingResponse, litellm.ImageResponse ] = litellm.ModelResponse() if vertex_image_generation_class.is_image_generation_response( _json_response ): - litellm_model_response = ( + litellm_prediction_response = ( vertex_image_generation_class.process_image_generation_response( _json_response, model_response=litellm.ImageResponse(), @@ -136,18 +136,19 @@ class PassThroughEndpointLogging: PassthroughCallTypes.passthrough_image_generation.value ) else: - litellm_model_response = await transform_vertex_response_to_openai( + litellm_prediction_response = await transform_vertex_response_to_openai( response=_json_response, model=model, model_response=litellm.EmbeddingResponse(), ) + if isinstance(litellm_prediction_response, litellm.EmbeddingResponse): + litellm_prediction_response.model = model - litellm_model_response.model = model - logging_obj.model = litellm_model_response.model + logging_obj.model = model logging_obj.model_call_details["model"] = logging_obj.model await logging_obj.async_success_handler( - result=litellm_model_response, + result=litellm_prediction_response, start_time=start_time, end_time=end_time, cache_hit=cache_hit,
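
For reviewers, a minimal sketch of what the new `transform_vertex_response_to_openai` helper from PATCH 01/12 does with a Vertex `:predict` embedding payload. The payload below is hypothetical, shaped after the keys the handler actually reads (`embeddings.values` and `embeddings.statistics.token_count`), and `textembedding-gecko` is only a placeholder model name.

```python
import asyncio

import litellm
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
    transform_vertex_response_to_openai,
)

# Hypothetical Vertex ":predict" embedding response; one prediction per input.
vertex_response = {
    "predictions": [
        {
            "embeddings": {
                "values": [0.011, -0.042, 0.093],
                "statistics": {"token_count": 4, "truncated": False},
            }
        }
    ]
}


async def main() -> None:
    openai_style = await transform_vertex_response_to_openai(
        response=vertex_response,
        model="textembedding-gecko",  # placeholder model name
        model_response=litellm.EmbeddingResponse(),
    )
    # data holds {"object": "embedding", "index": i, "embedding": [...]} entries;
    # usage.prompt_tokens is the sum of the per-prediction token counts.
    print(openai_style.object, openai_style.usage.prompt_tokens)  # -> list 4


asyncio.run(main())
```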
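PATCH 05/12 makes the pass-through success handler branch on the response body: a `:predict` result is logged as image generation when the first prediction carries base64 image bytes, otherwise it is treated as an embedding response. The standalone predicate below restates that rule as a sketch; it is not the shipped `is_image_generation_response`, and it adds a guard for an empty `predictions` list.

```python
from typing import Any, Dict


def looks_like_image_generation(json_response: Dict[str, Any]) -> bool:
    """Approximation of VertexImageGeneration.is_image_generation_response."""
    predictions = json_response.get("predictions") or []
    return bool(predictions) and "bytesBase64Encoded" in predictions[0]


# Payload shapes taken from the examples in the diffs above.
image_payload = {
    "predictions": [
        {"bytesBase64Encoded": "BASE64_IMG_BYTES", "mimeType": "image/png"}
    ]
}
embedding_payload = {
    "predictions": [
        {"embeddings": {"values": [0.1], "statistics": {"token_count": 1}}}
    ]
}

assert looks_like_image_generation(image_payload)
assert not looks_like_image_generation(embedding_payload)
```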
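Finally, a hedged end-to-end example of calling one of the Imagen models documented in PATCH 10/12 through `litellm.image_generation`. The project and location values are placeholders, and working Vertex credentials (for example via `VERTEXAI_CREDENTIALS` or application-default credentials) are assumed.

```python
import litellm

# Placeholder GCP settings; replace with your own project and region.
litellm.vertex_project = "my-gcp-project"
litellm.vertex_location = "us-central1"

response = litellm.image_generation(
    model="vertex_ai/imagegeneration@006",
    prompt="An olympic size swimming pool",
)

# response.data is a list of openai.types.image.Image objects whose b64_json
# fields carry the bytesBase64Encoded payloads returned by Vertex AI.
print(len(response.data))
```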