Merge pull request #4461 from t968914/litellm-fix-vertexaibeta

fix: Include vertex_ai_beta in vertex_ai param mapping/Do not use google auth project_id
Krish Dholakia 2024-07-04 15:27:20 -07:00 committed by GitHub
commit 18d8edc145
3 changed files with 12 additions and 5 deletions


@@ -106,7 +106,7 @@ aleph_alpha_key: Optional[str] = None
 nlp_cloud_key: Optional[str] = None
 common_cloud_provider_auth_params: dict = {
     "params": ["project", "region_name", "token"],
-    "providers": ["vertex_ai", "bedrock", "watsonx", "azure"],
+    "providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"],
 }
 use_client: bool = False
 ssl_verify: bool = True
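With "vertex_ai_beta" added to the provider list, the generic auth params ("project", "region_name", "token") are now picked up for the beta Vertex AI route as well. A minimal call-site sketch, assuming the usual litellm.completion entry point; the model name and the project/region values are illustrative:

import litellm

# "project" and "region_name" are common cloud-provider auth params; with this
# change they are mapped for vertex_ai_beta just as they already were for vertex_ai.
response = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-pro",
    messages=[{"role": "user", "content": "hello"}],
    project="my-gcp-project",     # illustrative value
    region_name="us-central1",    # illustrative value
)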


@@ -748,10 +748,12 @@ class VertexLLM(BaseLLM):
             if project_id is None:
                 project_id = creds.project_id
         else:
-            creds, project_id = google_auth.default(
+            creds, creds_project_id = google_auth.default(
                 quota_project_id=project_id,
                 scopes=["https://www.googleapis.com/auth/cloud-platform"],
             )
+            if project_id is None:
+                project_id = creds_project_id

         creds.refresh(Request())
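The behavioral fix here: google_auth.default() returns the project bound to the ambient credentials, and the old code let that value silently overwrite an explicitly passed project_id. A minimal sketch of the corrected resolution order, assuming google-auth is installed; resolve_project_id is a hypothetical helper, not part of litellm:

from typing import Optional, Tuple

import google.auth as google_auth

def resolve_project_id(project_id: Optional[str]) -> Tuple[object, Optional[str]]:
    # Load Application Default Credentials; creds_project_id is whatever
    # project the ambient credentials happen to be bound to.
    creds, creds_project_id = google_auth.default(
        quota_project_id=project_id,
        scopes=["https://www.googleapis.com/auth/cloud-platform"],
    )
    # Only fall back to the ADC project when the caller did not pass one,
    # instead of unconditionally overwriting the caller's project_id.
    if project_id is None:
        project_id = creds_project_id
    return creds, project_id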
@@ -974,7 +976,7 @@ class VertexLLM(BaseLLM):
         api_base: Optional[str] = None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore

         auth_header, url = self._get_token_and_url(
             model=model,
             gemini_api_key=gemini_api_key,


@@ -2412,7 +2412,7 @@ def get_optional_params(
         ):  # allow dynamically setting vertex ai init logic
             continue
         passed_params[k] = v

     optional_params: Dict = {}

     common_auth_dict = litellm.common_cloud_provider_auth_params
@@ -2431,7 +2431,7 @@ def get_optional_params(
                 non_default_params=passed_params, optional_params=optional_params
             )
         )
-    elif custom_llm_provider == "vertex_ai":
+    elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta":
         optional_params = litellm.VertexAIConfig().map_special_auth_params(
             non_default_params=passed_params, optional_params=optional_params
         )
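For both routes this branch translates the generic auth params into Vertex-specific ones. A sketch of the expected effect, under the assumption that map_special_auth_params renames "project" to vertex_project and "region_name" to vertex_location (values illustrative):

import litellm

passed_params = {"project": "my-gcp-project", "region_name": "us-central1"}
optional_params = litellm.VertexAIConfig().map_special_auth_params(
    non_default_params=passed_params, optional_params={}
)
# Expected under the assumption above:
# {"vertex_project": "my-gcp-project", "vertex_location": "us-central1"}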
@@ -3914,6 +3914,11 @@ def get_supported_openai_params(
             return litellm.VertexAIConfig().get_supported_openai_params()
         elif request_type == "embeddings":
             return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
+    elif custom_llm_provider == "vertex_ai_beta":
+        if request_type == "chat_completion":
+            return litellm.VertexAIConfig().get_supported_openai_params()
+        elif request_type == "embeddings":
+            return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "sagemaker":
         return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
     elif custom_llm_provider == "aleph_alpha":