diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py
index eb4cc864f..3f09e78f9 100644
--- a/litellm/llms/vertex_ai_anthropic.py
+++ b/litellm/llms/vertex_ai_anthropic.py
@@ -140,6 +140,7 @@ def completion(
     logging_obj,
     vertex_project=None,
     vertex_location=None,
+    vertex_credentials=None,
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
@@ -217,11 +218,24 @@
 
     ## Completion Call
     print_verbose(
-        f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
+        f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
     )
     if client is None:
+        if vertex_credentials is not None and isinstance(vertex_credentials, str):
+            import google.oauth2.service_account
+
+            json_obj = json.loads(vertex_credentials)
+
+            creds = (
+                google.oauth2.service_account.Credentials.from_service_account_info(
+                    json_obj
+                )
+            )
+
+            vertexai.init(credentials=creds)
         vertex_ai_client = AnthropicVertex(
-            project_id=vertex_project, region=vertex_location
+            project_id=vertex_project,
+            region=vertex_location,
         )
     else:
         vertex_ai_client = client
diff --git a/litellm/main.py b/litellm/main.py
index f23347942..3bc52448d 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1678,6 +1678,11 @@ def completion(
             )
 
         if "claude-3" in model:
+            vertex_credentials = (
+                optional_params.pop("vertex_credentials", None)
+                or optional_params.pop("vertex_ai_credentials", None)
+                or get_secret("VERTEXAI_CREDENTIALS")
+            )
             model_response = vertex_ai_anthropic.completion(
                 model=model,
                 messages=messages,
@@ -1689,6 +1694,7 @@
                 encoding=encoding,
                 vertex_location=vertex_ai_location,
                 vertex_project=vertex_ai_project,
+                vertex_credentials=vertex_credentials,
                 logging_obj=logging,
                 acompletion=acompletion,
             )
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 5df1085ae..42e2a8aba 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -23,6 +23,40 @@ user_message = "Write a short poem about the sky"
 messages = [{"content": user_message, "role": "user"}]
 
 
+def get_vertex_ai_creds_json() -> dict:
+    # Define the path to the vertex_key.json file
+    print("loading vertex ai credentials")
+    filepath = os.path.dirname(os.path.abspath(__file__))
+    vertex_key_path = filepath + "/vertex_key.json"
+
+    # Read the existing content of the file or create an empty dictionary
+    try:
+        with open(vertex_key_path, "r") as file:
+            # Read the file content
+            print("Read vertexai file path")
+            content = file.read()
+
+            # If the file is empty or not valid JSON, create an empty dictionary
+            if not content or not content.strip():
+                service_account_key_data = {}
+            else:
+                # Attempt to load the existing JSON content
+                file.seek(0)
+                service_account_key_data = json.load(file)
+    except FileNotFoundError:
+        # If the file doesn't exist, create an empty dictionary
+        service_account_key_data = {}
+
+    # Update the service_account_key_data with environment variables
+    private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
+    private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
+    private_key = private_key.replace("\\n", "\n")
+    service_account_key_data["private_key_id"] = private_key_id
+    service_account_key_data["private_key"] = private_key
+
+    return service_account_key_data
+
+
 def load_vertex_ai_credentials():
     # Define the path to the vertex_key.json file
     print("loading vertex ai credentials")
@@ -85,9 +119,9 @@ async def get_response():
         pytest.fail(f"An error occurred - {str(e)}")
 
 
-@pytest.mark.skip(
-    reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
-)
+# @pytest.mark.skip(
+#     reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
+# )
 def test_vertex_ai_anthropic():
     load_vertex_ai_credentials()
 
@@ -95,6 +129,8 @@
 
     vertex_ai_project = "adroit-crow-413218"
    vertex_ai_location = "asia-southeast1"
+    json_obj = get_vertex_ai_creds_json()
+    vertex_credentials = json.dumps(json_obj)
 
     response = completion(
         model="vertex_ai/" + model,
@@ -102,13 +138,14 @@
         temperature=0.7,
         vertex_ai_project=vertex_ai_project,
         vertex_ai_location=vertex_ai_location,
+        vertex_credentials=vertex_credentials,
     )
     print("\nModel Response", response)
 
 
-@pytest.mark.skip(
-    reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
-)
+# @pytest.mark.skip(
+#     reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
+# )
 def test_vertex_ai_anthropic_streaming():
     load_vertex_ai_credentials()
 
@@ -137,9 +174,9 @@
 # test_vertex_ai_anthropic_streaming()
 
 
-@pytest.mark.skip(
-    reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
-)
+# @pytest.mark.skip(
+#     reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
+# )
 @pytest.mark.asyncio
 async def test_vertex_ai_anthropic_async():
     load_vertex_ai_credentials()
@@ -162,9 +199,9 @@
 # asyncio.run(test_vertex_ai_anthropic_async())
 
 
-@pytest.mark.skip(
-    reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
-)
+# @pytest.mark.skip(
+#     reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
+# )
 @pytest.mark.asyncio
 async def test_vertex_ai_anthropic_async_streaming():
     load_vertex_ai_credentials()
@@ -180,7 +217,6 @@
         temperature=0.7,
         vertex_ai_project=vertex_ai_project,
         vertex_ai_location=vertex_ai_location,
-        stream=True,
     )
 
     async for chunk in response: