Merge pull request #5481 from BerriAI/litellm_track_imagen_spend_logs

[Feat-Proxy] track imagen /predict in LiteLLM spend logs
This commit is contained in:
Ishaan Jaff 2024-09-02 21:21:21 -07:00 committed by GitHub
commit 7aaabb51ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 342 additions and 246 deletions

View file

@ -1833,6 +1833,19 @@ response = await litellm.aimage_generation(
) )
``` ```
### Supported Image Generation Models
| Model Name | FUsage |
|------------------------------|--------------------------------------------------------------|
| `imagen-3.0-generate-001` | `litellm.image_generation('vertex_ai/imagen-3.0-generate-001', prompt)` |
| `imagen-3.0-fast-generate-001` | `litellm.image_generation('vertex_ai/imagen-3.0-fast-generate-001', prompt)` |
| `imagegeneration@006` | `litellm.image_generation('vertex_ai/imagegeneration@006', prompt)` |
| `imagegeneration@005` | `litellm.image_generation('vertex_ai/imagegeneration@005', prompt)` |
| `imagegeneration@002` | `litellm.image_generation('vertex_ai/imagegeneration@002', prompt)` |
## **Text to Speech APIs** ## **Text to Speech APIs**
:::info :::info

View file

@ -356,6 +356,7 @@ vertex_language_models: List = []
vertex_vision_models: List = [] vertex_vision_models: List = []
vertex_chat_models: List = [] vertex_chat_models: List = []
vertex_code_chat_models: List = [] vertex_code_chat_models: List = []
vertex_ai_image_models: List = []
vertex_text_models: List = [] vertex_text_models: List = []
vertex_code_text_models: List = [] vertex_code_text_models: List = []
vertex_embedding_models: List = [] vertex_embedding_models: List = []
@ -416,6 +417,9 @@ for key, value in model_cost.items():
elif value.get("litellm_provider") == "vertex_ai-ai21_models": elif value.get("litellm_provider") == "vertex_ai-ai21_models":
key = key.replace("vertex_ai/", "") key = key.replace("vertex_ai/", "")
vertex_ai_ai21_models.append(key) vertex_ai_ai21_models.append(key)
elif value.get("litellm_provider") == "vertex_ai-image-models":
key = key.replace("vertex_ai/", "")
vertex_ai_image_models.append(key)
elif value.get("litellm_provider") == "ai21": elif value.get("litellm_provider") == "ai21":
if value.get("mode") == "chat": if value.get("mode") == "chat":
ai21_chat_models.append(key) ai21_chat_models.append(key)

View file

@ -24,7 +24,7 @@ from litellm.llms.anthropic.cost_calculation import (
) )
from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
from litellm.types.utils import Usage from litellm.types.utils import PassthroughCallTypes, Usage
from litellm.utils import ( from litellm.utils import (
CallTypes, CallTypes,
CostPerToken, CostPerToken,
@ -625,6 +625,7 @@ def completion_cost(
if ( if (
call_type == CallTypes.image_generation.value call_type == CallTypes.image_generation.value
or call_type == CallTypes.aimage_generation.value or call_type == CallTypes.aimage_generation.value
or call_type == PassthroughCallTypes.passthrough_image_generation.value
): ):
### IMAGE GENERATION COST CALCULATION ### ### IMAGE GENERATION COST CALCULATION ###
if custom_llm_provider == "vertex_ai": if custom_llm_provider == "vertex_ai":

View file

@ -13,7 +13,6 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import httpx # type: ignore import httpx # type: ignore
import requests # type: ignore import requests # type: ignore
from openai.types.image import Image
import litellm import litellm
import litellm.litellm_core_utils import litellm.litellm_core_utils
@ -1488,248 +1487,6 @@ class VertexLLM(BaseLLM):
encoding=encoding, encoding=encoding,
) )
def image_generation(
self,
prompt: str,
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
model_response: litellm.ImageResponse,
model: Optional[
str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[Any] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
logging_obj=None,
aimg_generation=False,
):
if aimg_generation is True:
return self.aimage_generation(
prompt=prompt,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
model=model,
client=client,
optional_params=optional_params,
timeout=timeout,
logging_obj=logging_obj,
model_response=model_response,
)
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
_httpx_timeout = httpx.Timeout(timeout)
_params["timeout"] = _httpx_timeout
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore
else:
sync_handler = client # type: ignore
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
optional_params = optional_params or {
"sampleCount": 1
} # default optional params
request_data = {
"instances": [{"prompt": prompt}],
"parameters": optional_params,
}
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = sync_handler.post(
url=url,
headers={
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
},
data=json.dumps(request_data),
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}")
"""
Vertex AI Image generation response example:
{
"predictions": [
{
"bytesBase64Encoded": "BASE64_IMG_BYTES",
"mimeType": "image/png"
},
{
"mimeType": "image/png",
"bytesBase64Encoded": "BASE64_IMG_BYTES"
}
]
}
"""
_json_response = response.json()
if "predictions" not in _json_response:
raise litellm.InternalServerError(
message=f"image generation response does not contain 'predictions', got {_json_response}",
llm_provider="vertex_ai",
model=model,
)
_predictions = _json_response["predictions"]
_response_data: List[Image] = []
for _prediction in _predictions:
_bytes_base64_encoded = _prediction["bytesBase64Encoded"]
image_object = Image(b64_json=_bytes_base64_encoded)
_response_data.append(image_object)
model_response.data = _response_data
return model_response
async def aimage_generation(
self,
prompt: str,
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
model_response: litellm.ImageResponse,
model: Optional[
str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[AsyncHTTPHandler] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
logging_obj=None,
):
response = None
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
_httpx_timeout = httpx.Timeout(timeout)
_params["timeout"] = _httpx_timeout
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
self.async_handler = AsyncHTTPHandler(**_params) # type: ignore
else:
self.async_handler = client # type: ignore
# make POST request to
# https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
"""
Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d {
"instances": [
{
"prompt": "a cat"
}
],
"parameters": {
"sampleCount": 1
}
} \
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
"""
auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
optional_params = optional_params or {
"sampleCount": 1
} # default optional params
request_data = {
"instances": [{"prompt": prompt}],
"parameters": optional_params,
}
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = await self.async_handler.post(
url=url,
headers={
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
},
data=json.dumps(request_data),
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}")
"""
Vertex AI Image generation response example:
{
"predictions": [
{
"bytesBase64Encoded": "BASE64_IMG_BYTES",
"mimeType": "image/png"
},
{
"mimeType": "image/png",
"bytesBase64Encoded": "BASE64_IMG_BYTES"
}
]
}
"""
_json_response = response.json()
if "predictions" not in _json_response:
raise litellm.InternalServerError(
message=f"image generation response does not contain 'predictions', got {_json_response}",
llm_provider="vertex_ai",
model=model,
)
_predictions = _json_response["predictions"]
_response_data: List[Image] = []
for _prediction in _predictions:
_bytes_base64_encoded = _prediction["bytesBase64Encoded"]
image_object = Image(b64_json=_bytes_base64_encoded)
_response_data.append(image_object)
model_response.data = _response_data
return model_response
class ModelResponseIterator: class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool): def __init__(self, streaming_response, sync_stream: bool):

View file

@ -0,0 +1,225 @@
import json
from typing import Any, Dict, List, Optional
import httpx
from openai.types.image import Image
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
class VertexImageGeneration(VertexLLM):
def process_image_generation_response(
self,
json_response: Dict[str, Any],
model_response: litellm.ImageResponse,
model: Optional[str] = None,
) -> litellm.ImageResponse:
if "predictions" not in json_response:
raise litellm.InternalServerError(
message=f"image generation response does not contain 'predictions', got {json_response}",
llm_provider="vertex_ai",
model=model,
)
predictions = json_response["predictions"]
response_data: List[Image] = []
for prediction in predictions:
bytes_base64_encoded = prediction["bytesBase64Encoded"]
image_object = Image(b64_json=bytes_base64_encoded)
response_data.append(image_object)
model_response.data = response_data
return model_response
def image_generation(
self,
prompt: str,
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
model_response: litellm.ImageResponse,
model: Optional[
str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[Any] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
logging_obj=None,
aimg_generation=False,
):
if aimg_generation is True:
return self.aimage_generation(
prompt=prompt,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
model=model,
client=client,
optional_params=optional_params,
timeout=timeout,
logging_obj=logging_obj,
model_response=model_response,
)
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
_httpx_timeout = httpx.Timeout(timeout)
_params["timeout"] = _httpx_timeout
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore
else:
sync_handler = client # type: ignore
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
optional_params = optional_params or {
"sampleCount": 1
} # default optional params
request_data = {
"instances": [{"prompt": prompt}],
"parameters": optional_params,
}
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = sync_handler.post(
url=url,
headers={
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
},
data=json.dumps(request_data),
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}")
json_response = response.json()
return self.process_image_generation_response(
json_response, model_response, model
)
async def aimage_generation(
self,
prompt: str,
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
model_response: litellm.ImageResponse,
model: Optional[
str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[AsyncHTTPHandler] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
logging_obj=None,
):
response = None
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
_httpx_timeout = httpx.Timeout(timeout)
_params["timeout"] = _httpx_timeout
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
self.async_handler = AsyncHTTPHandler(**_params) # type: ignore
else:
self.async_handler = client # type: ignore
# make POST request to
# https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
"""
Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d {
"instances": [
{
"prompt": "a cat"
}
],
"parameters": {
"sampleCount": 1
}
} \
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
"""
auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
optional_params = optional_params or {
"sampleCount": 1
} # default optional params
request_data = {
"instances": [{"prompt": prompt}],
"parameters": optional_params,
}
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = await self.async_handler.post(
url=url,
headers={
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
},
data=json.dumps(request_data),
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}")
json_response = response.json()
return self.process_image_generation_response(
json_response, model_response, model
)
def is_image_generation_response(self, json_response: Dict[str, Any]) -> bool:
if "predictions" in json_response:
if "bytesBase64Encoded" in json_response["predictions"][0]:
return True
return False

View file

@ -281,3 +281,33 @@ async def async_embedding(
) )
setattr(model_response, "usage", usage) setattr(model_response, "usage", usage)
return model_response return model_response
async def transform_vertex_response_to_openai(
response: dict, model: str, model_response: litellm.EmbeddingResponse
) -> litellm.EmbeddingResponse:
_predictions = response["predictions"]
embedding_response = []
input_tokens: int = 0
for idx, element in enumerate(_predictions):
embedding = element["embeddings"]
embedding_response.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding["values"],
}
)
input_tokens += embedding["statistics"]["token_count"]
model_response.object = "list"
model_response.data = embedding_response
model_response.model = model
usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
)
setattr(model_response, "usage", usage)
return model_response

View file

@ -126,6 +126,9 @@ from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gem
from .llms.vertex_ai_and_google_ai_studio.gemini_embeddings.batch_embed_content_handler import ( from .llms.vertex_ai_and_google_ai_studio.gemini_embeddings.batch_embed_content_handler import (
GoogleBatchEmbeddings, GoogleBatchEmbeddings,
) )
from .llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
VertexImageGeneration,
)
from .llms.vertex_ai_and_google_ai_studio.multimodal_embeddings.embedding_handler import ( from .llms.vertex_ai_and_google_ai_studio.multimodal_embeddings.embedding_handler import (
VertexMultimodalEmbedding, VertexMultimodalEmbedding,
) )
@ -180,6 +183,7 @@ bedrock_converse_chat_completion = BedrockConverseLLM()
bedrock_embedding = BedrockEmbedding() bedrock_embedding = BedrockEmbedding()
vertex_chat_completion = VertexLLM() vertex_chat_completion = VertexLLM()
vertex_multimodal_embedding = VertexMultimodalEmbedding() vertex_multimodal_embedding = VertexMultimodalEmbedding()
vertex_image_generation = VertexImageGeneration()
google_batch_embeddings = GoogleBatchEmbeddings() google_batch_embeddings = GoogleBatchEmbeddings()
vertex_partner_models_chat_completion = VertexAIPartnerModels() vertex_partner_models_chat_completion = VertexAIPartnerModels()
vertex_text_to_speech = VertexTextToSpeechAPI() vertex_text_to_speech = VertexTextToSpeechAPI()
@ -4538,7 +4542,7 @@ def image_generation(
or optional_params.pop("vertex_ai_credentials", None) or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS") or get_secret("VERTEXAI_CREDENTIALS")
) )
model_response = vertex_chat_completion.image_generation( model_response = vertex_image_generation.image_generation(
model=model, model=model,
prompt=prompt, prompt=prompt,
timeout=timeout, timeout=timeout,

View file

@ -1,5 +1,6 @@
import re import re
from datetime import datetime from datetime import datetime
from typing import Union
import httpx import httpx
@ -103,3 +104,52 @@ class PassThroughEndpointLogging:
end_time=end_time, end_time=end_time,
cache_hit=cache_hit, cache_hit=cache_hit,
) )
elif "predict" in url_route:
from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
VertexImageGeneration,
)
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
transform_vertex_response_to_openai,
)
from litellm.types.utils import PassthroughCallTypes
vertex_image_generation_class = VertexImageGeneration()
model = self.extract_model_from_url(url_route)
_json_response = httpx_response.json()
litellm_prediction_response: Union[
litellm.ModelResponse, litellm.EmbeddingResponse, litellm.ImageResponse
] = litellm.ModelResponse()
if vertex_image_generation_class.is_image_generation_response(
_json_response
):
litellm_prediction_response = (
vertex_image_generation_class.process_image_generation_response(
_json_response,
model_response=litellm.ImageResponse(),
model=model,
)
)
logging_obj.call_type = (
PassthroughCallTypes.passthrough_image_generation.value
)
else:
litellm_prediction_response = await transform_vertex_response_to_openai(
response=_json_response,
model=model,
model_response=litellm.EmbeddingResponse(),
)
if isinstance(litellm_prediction_response, litellm.EmbeddingResponse):
litellm_prediction_response.model = model
logging_obj.model = model
logging_obj.model_call_details["model"] = logging_obj.model
await logging_obj.async_success_handler(
result=litellm_prediction_response,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
)

View file

@ -844,7 +844,7 @@ async def _PROXY_track_cost_callback(
kwargs["stream"] == True and "complete_streaming_response" in kwargs kwargs["stream"] == True and "complete_streaming_response" in kwargs
): ):
raise Exception( raise Exception(
f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing" f"Model not in litellm model cost map. Passed model = {kwargs.get('model')} - Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
) )
except Exception as e: except Exception as e:
error_msg = f"error in tracking cost callback - {traceback.format_exc()}" error_msg = f"error in tracking cost callback - {traceback.format_exc()}"

View file

@ -70,6 +70,12 @@ def test_get_llm_provider_deepseek_custom_api_base():
os.environ.pop("DEEPSEEK_API_BASE") os.environ.pop("DEEPSEEK_API_BASE")
def test_get_llm_provider_vertex_ai_image_models():
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
model="imagegeneration@006", custom_llm_provider=None
)
assert custom_llm_provider == "vertex_ai"
def test_get_llm_provider_ai21_chat(): def test_get_llm_provider_ai21_chat():
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider( model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
model="jamba-1.5-large", model="jamba-1.5-large",
@ -93,3 +99,4 @@ def test_get_llm_provider_ai21_chat_test2():
assert custom_llm_provider == "ai21_chat" assert custom_llm_provider == "ai21_chat"
assert model == "jamba-1.5-large" assert model == "jamba-1.5-large"
assert api_base == "https://api.ai21.com/studio/v1" assert api_base == "https://api.ai21.com/studio/v1"

View file

@ -119,6 +119,10 @@ class CallTypes(Enum):
speech = "speech" speech = "speech"
class PassthroughCallTypes(Enum):
passthrough_image_generation = "passthrough-image-generation"
class TopLogprob(OpenAIObject): class TopLogprob(OpenAIObject):
token: str token: str
"""The token.""" """The token."""

View file

@ -4975,6 +4975,7 @@ def get_llm_provider(
or model in litellm.vertex_language_models or model in litellm.vertex_language_models
or model in litellm.vertex_embedding_models or model in litellm.vertex_embedding_models
or model in litellm.vertex_vision_models or model in litellm.vertex_vision_models
or model in litellm.vertex_ai_image_models
): ):
custom_llm_provider = "vertex_ai" custom_llm_provider = "vertex_ai"
## ai21 ## ai21