fix(vertex_ai.py): accept credentials as a json string

This commit is contained in:
Krrish Dholakia 2024-04-16 17:34:25 -07:00
parent fed8d61933
commit 060ac995d6
2 changed files with 29 additions and 2 deletions

View file

@ -349,7 +349,16 @@ def completion(
print_verbose( print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}" f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
) )
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account
json_obj = json.loads(vertex_credentials)
creds = google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
creds, _ = google.auth.default(quota_project_id=vertex_project) creds, _ = google.auth.default(quota_project_id=vertex_project)
print_verbose( print_verbose(
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}" f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
@ -1171,6 +1180,7 @@ def embedding(
encoding=None, encoding=None,
vertex_project=None, vertex_project=None,
vertex_location=None, vertex_location=None,
vertex_credentials=None,
aembedding=False, aembedding=False,
print_verbose=None, print_verbose=None,
): ):
@ -1191,6 +1201,16 @@ def embedding(
print_verbose( print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}" f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
) )
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account
json_obj = json.loads(vertex_credentials)
creds = google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
creds, _ = google.auth.default(quota_project_id=vertex_project) creds, _ = google.auth.default(quota_project_id=vertex_project)
print_verbose( print_verbose(
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}" f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"

View file

@ -1710,6 +1710,7 @@ def completion(
encoding=encoding, encoding=encoding,
vertex_location=vertex_ai_location, vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project, vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
logging_obj=logging, logging_obj=logging,
acompletion=acompletion, acompletion=acompletion,
) )
@ -2807,6 +2808,11 @@ def embedding(
or litellm.vertex_location or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION") or get_secret("VERTEXAI_LOCATION")
) )
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
response = vertex_ai.embedding( response = vertex_ai.embedding(
model=model, model=model,
@ -2817,6 +2823,7 @@ def embedding(
model_response=EmbeddingResponse(), model_response=EmbeddingResponse(),
vertex_project=vertex_ai_project, vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location, vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
aembedding=aembedding, aembedding=aembedding,
print_verbose=print_verbose, print_verbose=print_verbose,
) )