diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index 11a1e0c6e..1ee9f434a 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -95,7 +95,7 @@ def completion( mode = "" request_str = "" - if model in litellm.vertex_chat_models: + if model in litellm.vertex_chat_models or ("chat" in model): # to catch chat-bison@003 or chat-bison@004 when google will release it chat_model = ChatModel.from_pretrained(model) mode = "chat" request_str += f"chat_model = ChatModel.from_pretrained({model})\n" diff --git a/litellm/main.py b/litellm/main.py index bf35d2a27..6204f0b60 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1136,7 +1136,7 @@ def completion( ) return response response = model_response - elif model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models or model in litellm.vertex_text_models or model in litellm.vertex_code_text_models: + elif model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models or model in litellm.vertex_text_models or model in litellm.vertex_code_text_models or custom_llm_provider == "vertex_ai": vertex_ai_project = (litellm.vertex_project or get_secret("VERTEXAI_PROJECT")) vertex_ai_location = (litellm.vertex_location