fix(main.py): accept vertex service account credentials as json string

allows us to dynamically set Vertex AI credentials per request, by passing a service account key as a JSON string (vertex_credentials) or via the VERTEXAI_CREDENTIALS secret
Krrish Dholakia 2024-04-15 13:28:59 -07:00
parent c769c3f39b
commit 50081479f9
3 changed files with 71 additions and 15 deletions
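
For reference, a minimal usage sketch of what this change enables (the key-file path, project id, and model name below are placeholders, not part of this commit):

import json

from litellm import completion

# Read a service-account key file and pass its contents as a JSON string.
# "vertex_key.json", "my-gcp-project", and the model name are placeholders.
with open("vertex_key.json", "r") as f:
    vertex_credentials = json.dumps(json.load(f))

response = completion(
    model="vertex_ai/claude-3-sonnet@20240229",  # any claude-3 model on Vertex AI
    messages=[{"role": "user", "content": "Write a short poem about the sky"}],
    vertex_ai_project="my-gcp-project",
    vertex_ai_location="us-central1",
    vertex_credentials=vertex_credentials,  # new: credentials as a JSON string
)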

litellm/llms/vertex_ai_anthropic.py

@@ -140,6 +140,7 @@ def completion(
     logging_obj,
     vertex_project=None,
     vertex_location=None,
+    vertex_credentials=None,
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
@@ -217,11 +218,24 @@ def completion(
         ## Completion Call
         print_verbose(
-            f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
+            f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
         )
         if client is None:
+            if vertex_credentials is not None and isinstance(vertex_credentials, str):
+                import google.oauth2.service_account
+
+                json_obj = json.loads(vertex_credentials)
+
+                creds = (
+                    google.oauth2.service_account.Credentials.from_service_account_info(
+                        json_obj
+                    )
+                )
+                vertexai.init(credentials=creds)
+
             vertex_ai_client = AnthropicVertex(
-                project_id=vertex_project, region=vertex_location
+                project_id=vertex_project,
+                region=vertex_location,
             )
         else:
             vertex_ai_client = client
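
The handler change above decodes the JSON string into google-auth credentials before the Anthropic Vertex client is created. A standalone sketch of that credential path (the file path, project, and region are placeholders; a real service-account key is required for it to run):

import json

import vertexai
from google.oauth2 import service_account

# Load the key file and rebuild the same objects the handler constructs.
with open("vertex_key.json", "r") as f:
    vertex_credentials = f.read()  # the JSON string the handler now accepts

creds = service_account.Credentials.from_service_account_info(
    json.loads(vertex_credentials)
)
vertexai.init(project="my-gcp-project", location="us-central1", credentials=creds)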

litellm/main.py

@@ -1678,6 +1678,11 @@ def completion(
             )
             if "claude-3" in model:
+                vertex_credentials = (
+                    optional_params.pop("vertex_credentials", None)
+                    or optional_params.pop("vertex_ai_credentials", None)
+                    or get_secret("VERTEXAI_CREDENTIALS")
+                )
                 model_response = vertex_ai_anthropic.completion(
                     model=model,
                     messages=messages,
@@ -1689,6 +1694,7 @@ def completion(
                     encoding=encoding,
                     vertex_location=vertex_ai_location,
                     vertex_project=vertex_ai_project,
+                    vertex_credentials=vertex_credentials,
                     logging_obj=logging,
                     acompletion=acompletion,
                 )
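
Per the fallback chain above, the JSON string can also come from the VERTEXAI_CREDENTIALS secret rather than a per-request parameter. A sketch of that environment-based setup, assuming get_secret falls back to environment variables (paths, project, and model are placeholders):

import json
import os

from litellm import completion

# Export the key contents once; get_secret("VERTEXAI_CREDENTIALS") picks it up
# from the environment (or a configured secret manager).
with open("vertex_key.json", "r") as f:
    os.environ["VERTEXAI_CREDENTIALS"] = json.dumps(json.load(f))

response = completion(
    model="vertex_ai/claude-3-sonnet@20240229",  # placeholder claude-3 model
    messages=[{"role": "user", "content": "hello"}],
    vertex_ai_project="my-gcp-project",
    vertex_ai_location="us-central1",
)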

litellm/tests/test_amazing_vertex_completion.py

@@ -23,6 +23,40 @@ user_message = "Write a short poem about the sky"
 messages = [{"content": user_message, "role": "user"}]
 
 
+def get_vertex_ai_creds_json() -> dict:
+    # Define the path to the vertex_key.json file
+    print("loading vertex ai credentials")
+    filepath = os.path.dirname(os.path.abspath(__file__))
+    vertex_key_path = filepath + "/vertex_key.json"
+
+    # Read the existing content of the file or create an empty dictionary
+    try:
+        with open(vertex_key_path, "r") as file:
+            # Read the file content
+            print("Read vertexai file path")
+            content = file.read()
+
+            # If the file is empty or not valid JSON, create an empty dictionary
+            if not content or not content.strip():
+                service_account_key_data = {}
+            else:
+                # Attempt to load the existing JSON content
+                file.seek(0)
+                service_account_key_data = json.load(file)
+    except FileNotFoundError:
+        # If the file doesn't exist, create an empty dictionary
+        service_account_key_data = {}
+
+    # Update the service_account_key_data with environment variables
+    private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
+    private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
+    private_key = private_key.replace("\\n", "\n")
+    service_account_key_data["private_key_id"] = private_key_id
+    service_account_key_data["private_key"] = private_key
+
+    return service_account_key_data
+
+
 def load_vertex_ai_credentials():
     # Define the path to the vertex_key.json file
     print("loading vertex ai credentials")
@@ -85,9 +119,9 @@ async def get_response():
         pytest.fail(f"An error occurred - {str(e)}")
 
 
-@pytest.mark.skip(
-    reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
-)
+# @pytest.mark.skip(
+#     reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
+# )
 def test_vertex_ai_anthropic():
     load_vertex_ai_credentials()
@@ -95,6 +129,8 @@ def test_vertex_ai_anthropic():
     vertex_ai_project = "adroit-crow-413218"
     vertex_ai_location = "asia-southeast1"
+    json_obj = get_vertex_ai_creds_json()
+    vertex_credentials = json.dumps(json_obj)
 
     response = completion(
         model="vertex_ai/" + model,
@@ -102,13 +138,14 @@ def test_vertex_ai_anthropic():
         temperature=0.7,
         vertex_ai_project=vertex_ai_project,
         vertex_ai_location=vertex_ai_location,
+        vertex_credentials=vertex_credentials,
     )
     print("\nModel Response", response)
 
 
-@pytest.mark.skip(
-    reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
-)
+# @pytest.mark.skip(
+#     reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
+# )
 def test_vertex_ai_anthropic_streaming():
     load_vertex_ai_credentials()
@@ -137,9 +174,9 @@ def test_vertex_ai_anthropic_streaming():
 # test_vertex_ai_anthropic_streaming()
 
 
-@pytest.mark.skip(
-    reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
-)
+# @pytest.mark.skip(
+#     reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
+# )
 @pytest.mark.asyncio
 async def test_vertex_ai_anthropic_async():
     load_vertex_ai_credentials()
@@ -162,9 +199,9 @@ async def test_vertex_ai_anthropic_async():
 # asyncio.run(test_vertex_ai_anthropic_async())
 
 
-@pytest.mark.skip(
-    reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
-)
+# @pytest.mark.skip(
+#     reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
+# )
 @pytest.mark.asyncio
 async def test_vertex_ai_anthropic_async_streaming():
     load_vertex_ai_credentials()
@@ -180,7 +217,6 @@ async def test_vertex_ai_anthropic_async_streaming():
         temperature=0.7,
         vertex_ai_project=vertex_ai_project,
         vertex_ai_location=vertex_ai_location,
-        stream=True,
     )
 
     async for chunk in response: