From b0c1d235beafe18229afc7e261733102cd6a7861 Mon Sep 17 00:00:00 2001
From: Tiger Yu
Date: Fri, 28 Jun 2024 10:36:58 -0700
Subject: [PATCH] Include vertex_ai_beta in vertex_ai param mapping

---
 litellm/__init__.py          |  2 +-
 litellm/llms/vertex_httpx.py |  6 ++++--
 litellm/utils.py             | 11 ++++++++---
 3 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index a8d9a80a2..381c4c530 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -106,7 +106,7 @@ aleph_alpha_key: Optional[str] = None
 nlp_cloud_key: Optional[str] = None
 common_cloud_provider_auth_params: dict = {
     "params": ["project", "region_name", "token"],
-    "providers": ["vertex_ai", "bedrock", "watsonx", "azure"],
+    "providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"],
 }
 use_client: bool = False
 ssl_verify: bool = True
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 18b1088ba..95a5ed492 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -733,10 +733,12 @@ class VertexLLM(BaseLLM):
             if project_id is None:
                 project_id = creds.project_id
         else:
-            creds, project_id = google_auth.default(
+            creds, creds_project_id = google_auth.default(
                 quota_project_id=project_id,
                 scopes=["https://www.googleapis.com/auth/cloud-platform"],
             )
+            if project_id is None:
+                project_id = creds_project_id
 
         creds.refresh(Request())
 
@@ -953,7 +955,7 @@ class VertexLLM(BaseLLM):
         api_base: Optional[str] = None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore
-
+
         auth_header, url = self._get_token_and_url(
             model=model,
             gemini_api_key=gemini_api_key,
diff --git a/litellm/utils.py b/litellm/utils.py
index c53e8f338..91f390f09 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2324,11 +2324,11 @@ def get_optional_params(
         elif k == "hf_model_name" and custom_llm_provider != "sagemaker":
             continue
         elif (
-            k.startswith("vertex_") and custom_llm_provider != "vertex_ai"
+            k.startswith("vertex_") and custom_llm_provider != "vertex_ai" and custom_llm_provider != "vertex_ai_beta"
         ):  # allow dynamically setting vertex ai init logic
             continue
         passed_params[k] = v
-
+
     optional_params: Dict = {}
 
     common_auth_dict = litellm.common_cloud_provider_auth_params
@@ -2347,7 +2347,7 @@ def get_optional_params(
                     non_default_params=passed_params, optional_params=optional_params
                 )
             )
-        elif custom_llm_provider == "vertex_ai":
+        elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta":
             optional_params = litellm.VertexAIConfig().map_special_auth_params(
                 non_default_params=passed_params, optional_params=optional_params
             )
@@ -3826,6 +3826,11 @@ def get_supported_openai_params(
             return litellm.VertexAIConfig().get_supported_openai_params()
         elif request_type == "embeddings":
             return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
+    elif custom_llm_provider == "vertex_ai_beta":
+        if request_type == "chat_completion":
+            return litellm.VertexAIConfig().get_supported_openai_params()
+        elif request_type == "embeddings":
+            return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "sagemaker":
         return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
     elif custom_llm_provider == "aleph_alpha":
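
Usage note (not part of the commit): a minimal sketch of the behavior this
patch enables, assuming hypothetical model/project/location values. With
vertex_ai_beta added to common_cloud_provider_auth_params and to the
vertex_* passthrough check in get_optional_params, dynamic Vertex init
params now reach vertex_ai_beta models the same way they reach vertex_ai:

    import litellm

    # All values below are placeholders for illustration only.
    response = litellm.completion(
        model="vertex_ai_beta/gemini-1.5-pro",
        messages=[{"role": "user", "content": "Hello"}],
        vertex_project="my-gcp-project",   # now passes through for vertex_ai_beta
        vertex_location="us-central1",
    )
    print(response.choices[0].message.content)

If vertex_project is omitted, the vertex_httpx.py change falls back to the
project id that google_auth.default() resolves from the environment.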