Fix VertexAI Credential Caching issue (#9756)

* refactor(vertex_llm_base.py): Prevent credential misrouting for projects

Fixes https://github.com/BerriAI/litellm/issues/7904

* fix: get unit tests passing

* fix(vertex_llm_base.py): common auth logic across sync + async vertex ai calls

prevents the credential caching issue across both flows (see the sketch after this commit message)

* test: fix test

* fix(vertex_llm_base.py): handle project id in default case

* fix(factory.py): don't pass cache control if not set

bedrock invoke does not support this

* test: fix test

* fix(vertex_llm_base.py): add .exception message in load_auth

* fix: fix ruff error
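
The gist of the fix, as a minimal self-contained sketch: previously a single self._credentials was cached on the client instance, so a request for one project could be answered with another project's token. The fix keys the cache on the (credentials, project_id) pair. Types are simplified here, and _load is a hypothetical stub standing in for load_auth; the mapping and key names mirror the diff below.

import json
from typing import Any, Dict, Optional, Tuple

class CredentialCacheSketch:
    def __init__(self) -> None:
        # one credential object per (serialized credentials, project_id) pair,
        # instead of a single self._credentials shared by every project
        self._credentials_project_mapping: Dict[
            Tuple[Optional[str], Optional[str]], Any
        ] = {}

    def _load(self, credentials: Optional[dict], project_id: Optional[str]) -> Any:
        # hypothetical stub for load_auth(), which resolves real credentials
        return type(
            "Creds", (), {"token": "fake-token", "project_id": project_id or "default-project"}
        )()

    def get_access_token(
        self, credentials: Optional[dict], project_id: Optional[str]
    ) -> Tuple[str, str]:
        # dict credentials are serialized so the cache key stays hashable
        cache_credentials = (
            json.dumps(credentials) if isinstance(credentials, dict) else credentials
        )
        cache_key = (cache_credentials, project_id)
        if cache_key not in self._credentials_project_mapping:
            self._credentials_project_mapping[cache_key] = self._load(
                credentials, project_id
            )
        creds = self._credentials_project_mapping[cache_key]
        return creds.token, project_id or creds.project_id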
Krish Dholakia 2025-04-04 16:38:08 -07:00 committed by GitHub
parent cdd351a03b
commit e1f7bcb47d
4 changed files with 290 additions and 57 deletions


@@ -121,6 +121,7 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
kwargs
)
headers = await self.construct_request_headers(
vertex_instance=gcs_logging_config["vertex_instance"],
service_account_json=gcs_logging_config["path_service_account"],


@@ -6,7 +6,7 @@ Handles Authentication and generating request urls for Vertex AI and Google AI Studio
import json
import os
-from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.asyncify import asyncify
@@ -28,6 +28,10 @@ class VertexBase(BaseLLM):
self.access_token: Optional[str] = None
self.refresh_token: Optional[str] = None
self._credentials: Optional[GoogleCredentialsObject] = None
+        self._credentials_project_mapping: Dict[
+            Tuple[Optional[VERTEX_CREDENTIALS_TYPES], Optional[str]],
+            GoogleCredentialsObject,
+        ] = {}
self.project_id: Optional[str] = None
self.async_handler: Optional[AsyncHTTPHandler] = None
@@ -128,32 +132,11 @@
"""
if custom_llm_provider == "gemini":
return "", ""
-        if self.access_token is not None:
-            if project_id is not None:
-                return self.access_token, project_id
-            elif self.project_id is not None:
-                return self.access_token, self.project_id
-        if not self._credentials:
-            self._credentials, cred_project_id = self.load_auth(
-                credentials=credentials, project_id=project_id
-            )
-            if not self.project_id:
-                self.project_id = project_id or cred_project_id
-        else:
-            if self._credentials.expired or not self._credentials.token:
-                self.refresh_auth(self._credentials)
-            if not self.project_id:
-                self.project_id = self._credentials.quota_project_id
-        if not self.project_id:
-            raise ValueError("Could not resolve project_id")
-        if not self._credentials or not self._credentials.token:
-            raise RuntimeError("Could not resolve API token from the environment")
-        return self._credentials.token, project_id or self.project_id
+        return self.get_access_token(
+            credentials=credentials,
+            project_id=project_id,
+        )
def is_using_v1beta1_features(self, optional_params: dict) -> bool:
"""
@@ -259,6 +242,101 @@
url=url,
)
def get_access_token(
self,
credentials: Optional[VERTEX_CREDENTIALS_TYPES],
project_id: Optional[str],
) -> Tuple[str, str]:
"""
Get access token and project id
1. Check if credentials are already in self._credentials_project_mapping
2. If not, load credentials and add to self._credentials_project_mapping
3. Check if loaded credentials have expired
4. If expired, refresh credentials
5. Return access token and project id
"""
# Convert dict credentials to string for caching
cache_credentials = (
json.dumps(credentials) if isinstance(credentials, dict) else credentials
)
credential_cache_key = (cache_credentials, project_id)
_credentials: Optional[GoogleCredentialsObject] = None
verbose_logger.debug(
f"Checking cached credentials for project_id: {project_id}"
)
if credential_cache_key in self._credentials_project_mapping:
verbose_logger.debug(
f"Cached credentials found for project_id: {project_id}."
)
_credentials = self._credentials_project_mapping[credential_cache_key]
verbose_logger.debug("Using cached credentials")
credential_project_id = _credentials.quota_project_id or getattr(
_credentials, "project_id", None
)
else:
verbose_logger.debug(
f"Credential cache key not found for project_id: {project_id}, loading new credentials"
)
try:
_credentials, credential_project_id = self.load_auth(
credentials=credentials, project_id=project_id
)
except Exception as e:
verbose_logger.exception(
"Failed to load vertex credentials. Check to see if credentials containing partial/invalid information."
)
raise e
if _credentials is None:
raise ValueError(
"Could not resolve credentials - either dynamically or from environment, for project_id: {}".format(
project_id
)
)
self._credentials_project_mapping[credential_cache_key] = _credentials
## VALIDATE CREDENTIALS
verbose_logger.debug(f"Validating credentials for project_id: {project_id}")
if (
project_id is not None
and credential_project_id
and credential_project_id != project_id
):
raise ValueError(
"Could not resolve project_id. Credential project_id: {} does not match requested project_id: {}".format(
credential_project_id, project_id
)
)
elif (
project_id is None
and credential_project_id is not None
and isinstance(credential_project_id, str)
):
project_id = credential_project_id
if _credentials.expired:
self.refresh_auth(_credentials)
## VALIDATION STEP
if _credentials.token is None or not isinstance(_credentials.token, str):
raise ValueError(
"Could not resolve credentials token. Got None or non-string token - {}".format(
_credentials.token
)
)
if project_id is None:
raise ValueError("Could not resolve project_id")
return _credentials.token, project_id
async def _ensure_access_token_async(
self,
credentials: Optional[VERTEX_CREDENTIALS_TYPES],
@@ -272,38 +350,14 @@
"""
if custom_llm_provider == "gemini":
return "", ""
-        if self.access_token is not None:
-            if project_id is not None:
-                return self.access_token, project_id
-            elif self.project_id is not None:
-                return self.access_token, self.project_id
-        if not self._credentials:
-            try:
-                self._credentials, cred_project_id = await asyncify(self.load_auth)(
-                    credentials=credentials, project_id=project_id
-                )
-            except Exception:
-                verbose_logger.exception(
-                    "Failed to load vertex credentials. Check to see if credentials containing partial/invalid information."
-                )
-                raise
-            if not self.project_id:
-                self.project_id = project_id or cred_project_id
-        else:
-            if self._credentials.expired or not self._credentials.token:
-                await asyncify(self.refresh_auth)(self._credentials)
-            if not self.project_id:
-                self.project_id = self._credentials.quota_project_id
-        if not self.project_id:
-            raise ValueError("Could not resolve project_id")
-        if not self._credentials or not self._credentials.token:
-            raise RuntimeError("Could not resolve API token from the environment")
-        return self._credentials.token, project_id or self.project_id
+        try:
+            return await asyncify(self.get_access_token)(
+                credentials=credentials,
+                project_id=project_id,
+            )
+        except Exception as e:
+            raise e
def set_headers(
self, auth_header: Optional[str], extra_headers: Optional[dict]

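Both _ensure_access_token and _ensure_access_token_async now funnel into the get_access_token path above (the async variant via asyncify), so the sync and async flows share one cache. A quick illustration of why the new cache key prevents misrouting; the service-account dicts are placeholders, not working credentials:

import json

creds_a = {"type": "service_account", "project_id": "project-a"}  # placeholder
creds_b = {"type": "service_account", "project_id": "project-b"}  # placeholder

key_a = (json.dumps(creds_a), "project-a")
key_b = (json.dumps(creds_b), "project-b")

# distinct keys -> distinct entries in _credentials_project_mapping,
# so project-b is never handed project-a's cached token
assert key_a != key_b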

@@ -0,0 +1,176 @@
import os
import sys
from unittest.mock import MagicMock, call, patch
import pytest
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
import litellm
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
def run_sync(coro):
"""Helper to run coroutine synchronously for testing"""
import asyncio
return asyncio.run(coro)
class TestVertexBase:
@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
@pytest.mark.asyncio
async def test_credential_project_validation(self, is_async):
vertex_base = VertexBase()
# Mock credentials with project_id "project-1"
mock_creds = MagicMock()
mock_creds.project_id = "project-1"
mock_creds.token = "fake-token-1"
mock_creds.expired = False
mock_creds.quota_project_id = "project-1"
# Test case 1: Ensure credentials match project
with patch.object(
vertex_base, "load_auth", return_value=(mock_creds, "project-1")
):
if is_async:
token, project = await vertex_base._ensure_access_token_async(
credentials={"type": "service_account", "project_id": "project-1"},
project_id="project-1",
custom_llm_provider="vertex_ai",
)
else:
token, project = vertex_base._ensure_access_token(
credentials={"type": "service_account", "project_id": "project-1"},
project_id="project-1",
custom_llm_provider="vertex_ai",
)
assert project == "project-1"
assert token == "fake-token-1"
# Test case 2: Prevent using credentials from different project
with patch.object(
vertex_base, "load_auth", return_value=(mock_creds, "project-1")
):
with pytest.raises(ValueError, match="Could not resolve project_id"):
if is_async:
result = await vertex_base._ensure_access_token_async(
credentials={"type": "service_account"},
project_id="different-project",
custom_llm_provider="vertex_ai",
)
else:
result = vertex_base._ensure_access_token(
credentials={"type": "service_account"},
project_id="different-project",
custom_llm_provider="vertex_ai",
)
print(f"result: {result}")
@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
@pytest.mark.asyncio
async def test_cached_credentials(self, is_async):
vertex_base = VertexBase()
# Initial credentials
mock_creds = MagicMock()
mock_creds.token = "token-1"
mock_creds.expired = False
mock_creds.project_id = "project-1"
mock_creds.quota_project_id = "project-1"
# Test initial credential load and caching
with patch.object(
vertex_base, "load_auth", return_value=(mock_creds, "project-1")
):
# First call should load credentials
if is_async:
token, project = await vertex_base._ensure_access_token_async(
credentials={"type": "service_account"},
project_id="project-1",
custom_llm_provider="vertex_ai",
)
else:
token, project = vertex_base._ensure_access_token(
credentials={"type": "service_account"},
project_id="project-1",
custom_llm_provider="vertex_ai",
)
assert token == "token-1"
# Second call should use cached credentials
if is_async:
token2, project2 = await vertex_base._ensure_access_token_async(
credentials={"type": "service_account"},
project_id="project-1",
custom_llm_provider="vertex_ai",
)
else:
token2, project2 = vertex_base._ensure_access_token(
credentials={"type": "service_account"},
project_id="project-1",
custom_llm_provider="vertex_ai",
)
assert token2 == "token-1"
assert project2 == "project-1"
@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
@pytest.mark.asyncio
async def test_credential_refresh(self, is_async):
vertex_base = VertexBase()
# Create expired credentials
mock_creds = MagicMock()
mock_creds.token = "my-token"
mock_creds.expired = True
mock_creds.project_id = "project-1"
mock_creds.quota_project_id = "project-1"
with patch.object(
vertex_base, "load_auth", return_value=(mock_creds, "project-1")
), patch.object(vertex_base, "refresh_auth") as mock_refresh:
def mock_refresh_impl(creds):
creds.token = "refreshed-token"
creds.expired = False
mock_refresh.side_effect = mock_refresh_impl
if is_async:
token, project = await vertex_base._ensure_access_token_async(
credentials={"type": "service_account"},
project_id="project-1",
custom_llm_provider="vertex_ai",
)
else:
token, project = vertex_base._ensure_access_token(
credentials={"type": "service_account"},
project_id="project-1",
custom_llm_provider="vertex_ai",
)
assert mock_refresh.called
assert token == "refreshed-token"
assert not mock_creds.expired
@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
@pytest.mark.asyncio
async def test_gemini_credentials(self, is_async):
vertex_base = VertexBase()
# Test that Gemini requests bypass credential checks
if is_async:
token, project = await vertex_base._ensure_access_token_async(
credentials=None, project_id=None, custom_llm_provider="gemini"
)
else:
token, project = vertex_base._ensure_access_token(
credentials=None, project_id=None, custom_llm_provider="gemini"
)
assert token == ""
assert project == ""


@@ -453,6 +453,7 @@ async def test_async_vertexai_response():
or "ultra" in model
or "002" in model
or "gemini-2.0-flash-thinking-exp" in model
or "gemini-2.0-pro-exp-02-05" in model
):
# our account does not have access to this model
continue
@@ -499,6 +500,7 @@ async def test_async_vertexai_streaming_response():
or "ultra" in model
or "002" in model
or "gemini-2.0-flash-thinking-exp" in model
or "gemini-2.0-pro-exp-02-05" in model
):
# our account does not have access to this model
continue