diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index aded41148..176902e1a 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -270,6 +270,7 @@ def completion( logging_obj, vertex_project=None, vertex_location=None, + vertex_credentials=None, optional_params=None, litellm_params=None, logger_fn=None, @@ -348,6 +349,7 @@ def completion( print_verbose( f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}" ) + creds, _ = google.auth.default(quota_project_id=vertex_project) print_verbose( f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}" diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py index 3f09e78f9..9bce746dd 100644 --- a/litellm/llms/vertex_ai_anthropic.py +++ b/litellm/llms/vertex_ai_anthropic.py @@ -129,6 +129,18 @@ class VertexAIAnthropicConfig: # makes headers for API call +def refresh_auth( + credentials, +) -> str: # used when user passes in credentials as json string + from google.auth.transport.requests import Request # type: ignore[import-untyped] + + if credentials.token is None: + credentials.refresh(Request()) + + if not credentials.token: + raise RuntimeError("Could not resolve API token from the credentials") + + return credentials.token def completion( @@ -220,6 +232,7 @@ def completion( print_verbose( f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}" ) + access_token = None if client is None: if vertex_credentials is not None and isinstance(vertex_credentials, str): import google.oauth2.service_account @@ -228,14 +241,17 @@ def completion( creds = ( google.oauth2.service_account.Credentials.from_service_account_info( - json_obj + json_obj, + scopes=["https://www.googleapis.com/auth/cloud-platform"], ) ) + ### CHECK IF ACCESS + access_token = refresh_auth(credentials=creds) - vertexai.init(credentials=creds) 
vertex_ai_client = AnthropicVertex( project_id=vertex_project, region=vertex_location, + access_token=access_token, ) else: vertex_ai_client = client @@ -257,6 +273,7 @@ def completion( vertex_location=vertex_location, optional_params=optional_params, client=client, + access_token=access_token, ) else: return async_completion( @@ -270,6 +287,7 @@ def completion( vertex_location=vertex_location, optional_params=optional_params, client=client, + access_token=access_token, ) if stream is not None and stream == True: ## LOGGING @@ -348,12 +366,13 @@ async def async_completion( vertex_location=None, optional_params=None, client=None, + access_token=None, ): from anthropic import AsyncAnthropicVertex if client is None: vertex_ai_client = AsyncAnthropicVertex( - project_id=vertex_project, region=vertex_location + project_id=vertex_project, region=vertex_location, access_token=access_token ) else: vertex_ai_client = client @@ -418,12 +437,13 @@ async def async_streaming( vertex_location=None, optional_params=None, client=None, + access_token=None, ): from anthropic import AsyncAnthropicVertex if client is None: vertex_ai_client = AsyncAnthropicVertex( - project_id=vertex_project, region=vertex_location + project_id=vertex_project, region=vertex_location, access_token=access_token ) else: vertex_ai_client = client diff --git a/litellm/main.py b/litellm/main.py index 3bc52448d..df9eb3e55 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1676,13 +1676,12 @@ def completion( or litellm.vertex_location or get_secret("VERTEXAI_LOCATION") ) - + vertex_credentials = ( + optional_params.pop("vertex_credentials", None) + or optional_params.pop("vertex_ai_credentials", None) + or get_secret("VERTEXAI_CREDENTIALS") + ) if "claude-3" in model: - vertex_credentials = ( - optional_params.pop("vertex_credentials", None) - or optional_params.pop("vertex_ai_credentials", None) - or get_secret("VERTEXAI_CREDENTIALS") - ) model_response = vertex_ai_anthropic.completion( model=model, 
messages=messages, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 9b920224e..69a7d8409 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -40,7 +40,7 @@ general_settings: master_key: sk-1234 allow_user_auth: true alerting: ["slack"] - # store_model_in_db: True // set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True" + store_model_in_db: True # set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True" proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds) enable_jwt_auth: True alerting: ["slack"] diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 42e2a8aba..16bddf44f 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -123,8 +123,6 @@ async def get_response(): # reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd." # ) def test_vertex_ai_anthropic(): - load_vertex_ai_credentials() - model = "claude-3-sonnet@20240229" vertex_ai_project = "adroit-crow-413218" @@ -179,12 +177,14 @@ def test_vertex_ai_anthropic_streaming(): # ) @pytest.mark.asyncio async def test_vertex_ai_anthropic_async(): - load_vertex_ai_credentials() + # load_vertex_ai_credentials() model = "claude-3-sonnet@20240229" vertex_ai_project = "adroit-crow-413218" vertex_ai_location = "asia-southeast1" + json_obj = get_vertex_ai_creds_json() + vertex_credentials = json.dumps(json_obj) response = await acompletion( model="vertex_ai/" + model, @@ -192,6 +192,7 @@ async def test_vertex_ai_anthropic_async(): temperature=0.7, vertex_ai_project=vertex_ai_project, vertex_ai_location=vertex_ai_location, + vertex_credentials=vertex_credentials, ) print(f"Model Response: {response}")