fix get_llm_provider for imagegeneration@006

This commit is contained in:
Ishaan Jaff 2024-09-02 17:47:29 -07:00
parent 811aa34a36
commit 80dd2cfc7f
2 changed files with 11 additions and 0 deletions

View file

@ -355,6 +355,7 @@ vertex_language_models: List = []
vertex_vision_models: List = []
vertex_chat_models: List = []
vertex_code_chat_models: List = []
vertex_ai_image_models: List = []
vertex_text_models: List = []
vertex_code_text_models: List = []
vertex_embedding_models: List = []
@ -414,6 +415,9 @@ for key, value in model_cost.items():
elif value.get("litellm_provider") == "vertex_ai-ai21_models":
key = key.replace("vertex_ai/", "")
vertex_ai_ai21_models.append(key)
elif value.get("litellm_provider") == "vertex_ai-image-models":
key = key.replace("vertex_ai/", "")
vertex_ai_image_models.append(key)
elif value.get("litellm_provider") == "ai21":
ai21_models.append(key)
elif value.get("litellm_provider") == "nlp_cloud":

View file

@ -68,3 +68,10 @@ def test_get_llm_provider_deepseek_custom_api_base():
assert api_base == "MY-FAKE-BASE"
os.environ.pop("DEEPSEEK_API_BASE")
def test_get_llm_provider_vertex_ai_image_models():
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
model="imagegeneration@006",
)
assert custom_llm_provider == "vertex_ai"