diff --git a/litellm/__init__.py b/litellm/__init__.py index 3f22e41b6f..4148fdc78c 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -355,6 +355,7 @@ vertex_language_models: List = [] vertex_vision_models: List = [] vertex_chat_models: List = [] vertex_code_chat_models: List = [] +vertex_ai_image_models: List = [] vertex_text_models: List = [] vertex_code_text_models: List = [] vertex_embedding_models: List = [] @@ -414,6 +415,9 @@ for key, value in model_cost.items(): elif value.get("litellm_provider") == "vertex_ai-ai21_models": key = key.replace("vertex_ai/", "") vertex_ai_ai21_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-image-models": + key = key.replace("vertex_ai/", "") + vertex_ai_image_models.append(key) elif value.get("litellm_provider") == "ai21": ai21_models.append(key) elif value.get("litellm_provider") == "nlp_cloud": diff --git a/litellm/tests/test_get_llm_provider.py b/litellm/tests/test_get_llm_provider.py index 5e1c1f4fec..4eef036a70 100644 --- a/litellm/tests/test_get_llm_provider.py +++ b/litellm/tests/test_get_llm_provider.py @@ -68,3 +68,10 @@ def test_get_llm_provider_deepseek_custom_api_base(): assert api_base == "MY-FAKE-BASE" os.environ.pop("DEEPSEEK_API_BASE") + + +def test_get_llm_provider_vertex_ai_image_models(): + model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider( + model="imagegeneration@006", + ) + assert custom_llm_provider == "vertex_ai"