From 060ac995d60f0368013e7ba87753533e254dbe72 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 16 Apr 2024 17:34:25 -0700 Subject: [PATCH] fix(vertex_ai.py): accept credentials as a json string --- litellm/llms/vertex_ai.py | 24 ++++++++++++++++++++++-- litellm/main.py | 7 +++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index 176902e1a..69feef63c 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -349,8 +349,17 @@ def completion( print_verbose( f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}" ) + if vertex_credentials is not None and isinstance(vertex_credentials, str): + import google.oauth2.service_account - creds, _ = google.auth.default(quota_project_id=vertex_project) + json_obj = json.loads(vertex_credentials) + + creds = google.oauth2.service_account.Credentials.from_service_account_info( + json_obj, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + else: + creds, _ = google.auth.default(quota_project_id=vertex_project) print_verbose( f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}" ) @@ -1171,6 +1180,7 @@ def embedding( encoding=None, vertex_project=None, vertex_location=None, + vertex_credentials=None, aembedding=False, print_verbose=None, ): @@ -1191,7 +1201,17 @@ def embedding( print_verbose( f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}" ) - creds, _ = google.auth.default(quota_project_id=vertex_project) + if vertex_credentials is not None and isinstance(vertex_credentials, str): + import google.oauth2.service_account + + json_obj = json.loads(vertex_credentials) + + creds = google.oauth2.service_account.Credentials.from_service_account_info( + json_obj, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + else: + creds, _ = google.auth.default(quota_project_id=vertex_project) print_verbose( f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}" ) diff --git a/litellm/main.py b/litellm/main.py index 4c4a9540e..593fc7eae 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1710,6 +1710,7 @@ def completion( encoding=encoding, vertex_location=vertex_ai_location, vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, logging_obj=logging, acompletion=acompletion, ) @@ -2807,6 +2808,11 @@ def embedding( or litellm.vertex_location or get_secret("VERTEXAI_LOCATION") ) + vertex_credentials = ( + optional_params.pop("vertex_credentials", None) + or optional_params.pop("vertex_ai_credentials", None) + or get_secret("VERTEXAI_CREDENTIALS") + ) response = vertex_ai.embedding( model=model, @@ -2817,6 +2823,7 @@ def embedding( model_response=EmbeddingResponse(), vertex_project=vertex_ai_project, vertex_location=vertex_ai_location, + vertex_credentials=vertex_credentials, aembedding=aembedding, print_verbose=print_verbose, )