diff --git a/litellm/__init__.py b/litellm/__init__.py
index e00e4f804e..cc67cd00ca 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -106,7 +106,7 @@ aleph_alpha_key: Optional[str] = None
 nlp_cloud_key: Optional[str] = None
 common_cloud_provider_auth_params: dict = {
     "params": ["project", "region_name", "token"],
-    "providers": ["vertex_ai", "bedrock", "watsonx", "azure"],
+    "providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"],
 }
 use_client: bool = False
 ssl_verify: bool = True
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 2ea0e199e8..62b9085771 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -748,10 +748,12 @@ class VertexLLM(BaseLLM):
             if project_id is None:
                 project_id = creds.project_id
         else:
-            creds, project_id = google_auth.default(
+            creds, creds_project_id = google_auth.default(
                 quota_project_id=project_id,
                 scopes=["https://www.googleapis.com/auth/cloud-platform"],
             )
+            if project_id is None:
+                project_id = creds_project_id
 
         creds.refresh(Request())
 
@@ -974,7 +976,7 @@ class VertexLLM(BaseLLM):
         api_base: Optional[str] = None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore
-
+
         auth_header, url = self._get_token_and_url(
             model=model,
             gemini_api_key=gemini_api_key,
diff --git a/litellm/utils.py b/litellm/utils.py
index dc2c5560da..9390021e0f 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2412,7 +2412,7 @@ def get_optional_params(
         ):  # allow dynamically setting vertex ai init logic
             continue
         passed_params[k] = v
-
+
     optional_params: Dict = {}
 
     common_auth_dict = litellm.common_cloud_provider_auth_params
@@ -2431,7 +2431,7 @@ def get_optional_params(
                 non_default_params=passed_params, optional_params=optional_params
             )
         )
-    elif custom_llm_provider == "vertex_ai":
+    elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta":
         optional_params = litellm.VertexAIConfig().map_special_auth_params(
             non_default_params=passed_params, optional_params=optional_params
         )
@@ -3914,6 +3914,11 @@ def get_supported_openai_params(
             return litellm.VertexAIConfig().get_supported_openai_params()
         elif request_type == "embeddings":
             return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
+    elif custom_llm_provider == "vertex_ai_beta":
+        if request_type == "chat_completion":
+            return litellm.VertexAIConfig().get_supported_openai_params()
+        elif request_type == "embeddings":
+            return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "sagemaker":
         return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
     elif custom_llm_provider == "aleph_alpha":
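
Taken together, these changes let the `vertex_ai_beta` route accept the same generic cloud-auth kwargs as `vertex_ai`. Below is a minimal sketch of the resulting call pattern, assuming `VertexAIConfig().map_special_auth_params` translates `project`/`region_name` into `vertex_project`/`vertex_location` the same way it does for `vertex_ai`; the model id, project, and region values are illustrative placeholders.

```python
# Sketch only: exercises the new vertex_ai_beta auth-param mapping.
# Assumes Application Default Credentials are configured locally
# (e.g. via `gcloud auth application-default login`).
import litellm

response = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-pro",  # illustrative model id
    messages=[{"role": "user", "content": "Say hello."}],
    project="my-gcp-project",   # mapped to vertex_project by map_special_auth_params
    region_name="us-central1",  # mapped to vertex_location
)
print(response.choices[0].message.content)
```

The `vertex_httpx.py` hunk matters for the ADC path in this same flow: `google_auth.default()` now only supplies the project when the caller did not pass one, so an explicitly provided `project_id` is no longer clobbered by the credential's default project.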