Mirror of https://github.com/BerriAI/litellm.git
Fix VertexAI Credential Caching issue (#9756)
* refactor(vertex_llm_base.py): prevent credential misrouting for projects. Fixes https://github.com/BerriAI/litellm/issues/7904
* fix: passing unit tests
* fix(vertex_llm_base.py): common auth logic across sync + async Vertex AI calls; prevents the credential caching issue in both flows
* test: fix test
* fix(vertex_llm_base.py): handle project id in default case
* fix(factory.py): don't pass cache control if not set; Bedrock invoke does not support it
* test: fix test
* fix(vertex_llm_base.py): add .exception message in load_auth
* fix: fix ruff error
Parent: cdd351a03b
Commit: e1f7bcb47d
4 changed files with 290 additions and 57 deletions
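The crux of the fix, sketched standalone (hypothetical names; the real logic lands in the vertex_llm_base.py hunks below): credentials are cached under a (credentials, project_id) key instead of a single shared attribute, so calls for different projects can no longer be handed each other's tokens.

import json
from typing import Dict, Optional, Tuple

# Hypothetical, minimal sketch of the keyed credential cache in this commit.
CredentialKey = Tuple[Optional[str], Optional[str]]
_credential_cache: Dict[CredentialKey, object] = {}


def credential_cache_key(credentials, project_id: Optional[str]) -> CredentialKey:
    # Dict credentials are serialized to a string so the tuple stays hashable,
    # mirroring the json.dumps call in get_access_token below.
    cache_credentials = (
        json.dumps(credentials) if isinstance(credentials, dict) else credentials
    )
    return (cache_credentials, project_id)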
@@ -121,6 +121,7 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
        gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
            kwargs
        )

        headers = await self.construct_request_headers(
            vertex_instance=gcs_logging_config["vertex_instance"],
            service_account_json=gcs_logging_config["path_service_account"],
@@ -6,7 +6,7 @@ Handles Authentication and generating request urls for Vertex AI and Google AI Studio

 import json
 import os
-from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple

 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.asyncify import asyncify
@@ -28,6 +28,10 @@ class VertexBase(BaseLLM):
         self.access_token: Optional[str] = None
         self.refresh_token: Optional[str] = None
         self._credentials: Optional[GoogleCredentialsObject] = None
+        self._credentials_project_mapping: Dict[
+            Tuple[Optional[VERTEX_CREDENTIALS_TYPES], Optional[str]],
+            GoogleCredentialsObject,
+        ] = {}
         self.project_id: Optional[str] = None
         self.async_handler: Optional[AsyncHTTPHandler] = None
@@ -128,32 +132,11 @@ class VertexBase(BaseLLM):
         """
         if custom_llm_provider == "gemini":
             return "", ""
-        if self.access_token is not None:
-            if project_id is not None:
-                return self.access_token, project_id
-            elif self.project_id is not None:
-                return self.access_token, self.project_id
-
-        if not self._credentials:
-            self._credentials, cred_project_id = self.load_auth(
-                credentials=credentials, project_id=project_id
-            )
-            if not self.project_id:
-                self.project_id = project_id or cred_project_id
-        else:
-            if self._credentials.expired or not self._credentials.token:
-                self.refresh_auth(self._credentials)
-
-            if not self.project_id:
-                self.project_id = self._credentials.quota_project_id
-
-        if not self.project_id:
-            raise ValueError("Could not resolve project_id")
-
-        if not self._credentials or not self._credentials.token:
-            raise RuntimeError("Could not resolve API token from the environment")
-
-        return self._credentials.token, project_id or self.project_id
+        return self.get_access_token(
+            credentials=credentials,
+            project_id=project_id,
+        )

     def is_using_v1beta1_features(self, optional_params: dict) -> bool:
         """
@@ -259,6 +242,101 @@ class VertexBase(BaseLLM):
             url=url,
         )

+    def get_access_token(
+        self,
+        credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+        project_id: Optional[str],
+    ) -> Tuple[str, str]:
+        """
+        Get access token and project id
+
+        1. Check if credentials are already in self._credentials_project_mapping
+        2. If not, load credentials and add to self._credentials_project_mapping
+        3. Check if loaded credentials have expired
+        4. If expired, refresh credentials
+        5. Return access token and project id
+        """
+
+        # Convert dict credentials to string for caching
+        cache_credentials = (
+            json.dumps(credentials) if isinstance(credentials, dict) else credentials
+        )
+        credential_cache_key = (cache_credentials, project_id)
+        _credentials: Optional[GoogleCredentialsObject] = None
+
+        verbose_logger.debug(
+            f"Checking cached credentials for project_id: {project_id}"
+        )
+
+        if credential_cache_key in self._credentials_project_mapping:
+            verbose_logger.debug(
+                f"Cached credentials found for project_id: {project_id}."
+            )
+            _credentials = self._credentials_project_mapping[credential_cache_key]
+            verbose_logger.debug("Using cached credentials")
+            credential_project_id = _credentials.quota_project_id or getattr(
+                _credentials, "project_id", None
+            )
+
+        else:
+            verbose_logger.debug(
+                f"Credential cache key not found for project_id: {project_id}, loading new credentials"
+            )
+
+            try:
+                _credentials, credential_project_id = self.load_auth(
+                    credentials=credentials, project_id=project_id
+                )
+            except Exception as e:
+                verbose_logger.exception(
+                    "Failed to load vertex credentials. Check to see if credentials containing partial/invalid information."
+                )
+                raise e
+
+            if _credentials is None:
+                raise ValueError(
+                    "Could not resolve credentials - either dynamically or from environment, for project_id: {}".format(
+                        project_id
+                    )
+                )
+
+            self._credentials_project_mapping[credential_cache_key] = _credentials
+
+        ## VALIDATE CREDENTIALS
+        verbose_logger.debug(f"Validating credentials for project_id: {project_id}")
+        if (
+            project_id is not None
+            and credential_project_id
+            and credential_project_id != project_id
+        ):
+            raise ValueError(
+                "Could not resolve project_id. Credential project_id: {} does not match requested project_id: {}".format(
+                    _credentials.quota_project_id, project_id
+                )
+            )
+        elif (
+            project_id is None
+            and credential_project_id is not None
+            and isinstance(credential_project_id, str)
+        ):
+            project_id = credential_project_id
+
+        if _credentials.expired:
+            self.refresh_auth(_credentials)
+
+        ## VALIDATION STEP
+        if _credentials.token is None or not isinstance(_credentials.token, str):
+            raise ValueError(
+                "Could not resolve credentials token. Got None or non-string token - {}".format(
+                    _credentials.token
+                )
+            )
+
+        if project_id is None:
+            raise ValueError("Could not resolve project_id")
+
+        return _credentials.token, project_id
+
     async def _ensure_access_token_async(
         self,
         credentials: Optional[VERTEX_CREDENTIALS_TYPES],
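To make the new cache behavior concrete, a small usage sketch (illustrative only: the service-account dicts are abbreviated and would need real key material for load_auth to succeed; the new tests below exercise the same calls with load_auth mocked):

from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

vertex_base = VertexBase()

# Abbreviated, hypothetical service-account dicts.
creds_a = {"type": "service_account", "project_id": "project-a"}
creds_b = {"type": "service_account", "project_id": "project-b"}

# Each (credentials, project_id) pair resolves to its own cache entry, so
# project-b is no longer handed project-a's token.
token_a, project_a = vertex_base._ensure_access_token(
    credentials=creds_a, project_id="project-a", custom_llm_provider="vertex_ai"
)
token_b, project_b = vertex_base._ensure_access_token(
    credentials=creds_b, project_id="project-b", custom_llm_provider="vertex_ai"
)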
@@ -272,38 +350,14 @@ class VertexBase(BaseLLM):
         """
         if custom_llm_provider == "gemini":
             return "", ""
-        if self.access_token is not None:
-            if project_id is not None:
-                return self.access_token, project_id
-            elif self.project_id is not None:
-                return self.access_token, self.project_id
-
-        if not self._credentials:
-            try:
-                self._credentials, cred_project_id = await asyncify(self.load_auth)(
-                    credentials=credentials, project_id=project_id
-                )
-            except Exception:
-                verbose_logger.exception(
-                    "Failed to load vertex credentials. Check to see if credentials containing partial/invalid information."
-                )
-                raise
-            if not self.project_id:
-                self.project_id = project_id or cred_project_id
-        else:
-            if self._credentials.expired or not self._credentials.token:
-                await asyncify(self.refresh_auth)(self._credentials)
-
-            if not self.project_id:
-                self.project_id = self._credentials.quota_project_id
-
-        if not self.project_id:
-            raise ValueError("Could not resolve project_id")
-
-        if not self._credentials or not self._credentials.token:
-            raise RuntimeError("Could not resolve API token from the environment")
-
-        return self._credentials.token, project_id or self.project_id
+        try:
+            return await asyncify(self.get_access_token)(
+                credentials=credentials,
+                project_id=project_id,
+            )
+        except Exception as e:
+            raise e

     def set_headers(
         self, auth_header: Optional[str], extra_headers: Optional[dict]
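Note the async entry point now simply wraps the shared sync get_access_token with asyncify, so both flows go through the same cache. A rough sketch of what an asyncify-style wrapper does, using the stdlib's asyncio.to_thread (an illustration of the pattern, not litellm's actual helper):

import asyncio


def asyncify_sketch(func):
    # Run a blocking callable in a worker thread so the event loop stays
    # responsive; litellm's asyncify is similar in spirit.
    async def wrapper(*args, **kwargs):
        return await asyncio.to_thread(func, *args, **kwargs)

    return wrapper

# Usage mirroring the diff above:
#   token, project = await asyncify_sketch(vertex_base.get_access_token)(
#       credentials=creds_a, project_id="project-a"
#   )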
tests/litellm/llms/vertex_ai/test_vertex_llm_base.py (new file, 176 lines)
@@ -0,0 +1,176 @@
import os
import sys
from unittest.mock import MagicMock, call, patch

import pytest

from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

import litellm
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase


def run_sync(coro):
    """Helper to run coroutine synchronously for testing"""
    import asyncio

    return asyncio.run(coro)


class TestVertexBase:
    @pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
    @pytest.mark.asyncio
    async def test_credential_project_validation(self, is_async):
        vertex_base = VertexBase()

        # Mock credentials with project_id "project-1"
        mock_creds = MagicMock()
        mock_creds.project_id = "project-1"
        mock_creds.token = "fake-token-1"
        mock_creds.expired = False
        mock_creds.quota_project_id = "project-1"

        # Test case 1: Ensure credentials match project
        with patch.object(
            vertex_base, "load_auth", return_value=(mock_creds, "project-1")
        ):
            if is_async:
                token, project = await vertex_base._ensure_access_token_async(
                    credentials={"type": "service_account", "project_id": "project-1"},
                    project_id="project-1",
                    custom_llm_provider="vertex_ai",
                )
            else:
                token, project = vertex_base._ensure_access_token(
                    credentials={"type": "service_account", "project_id": "project-1"},
                    project_id="project-1",
                    custom_llm_provider="vertex_ai",
                )
            assert project == "project-1"
            assert token == "fake-token-1"

        # Test case 2: Prevent using credentials from different project
        with patch.object(
            vertex_base, "load_auth", return_value=(mock_creds, "project-1")
        ):
            with pytest.raises(ValueError, match="Could not resolve project_id"):
                if is_async:
                    result = await vertex_base._ensure_access_token_async(
                        credentials={"type": "service_account"},
                        project_id="different-project",
                        custom_llm_provider="vertex_ai",
                    )
                else:
                    result = vertex_base._ensure_access_token(
                        credentials={"type": "service_account"},
                        project_id="different-project",
                        custom_llm_provider="vertex_ai",
                    )
                print(f"result: {result}")

    @pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
    @pytest.mark.asyncio
    async def test_cached_credentials(self, is_async):
        vertex_base = VertexBase()

        # Initial credentials
        mock_creds = MagicMock()
        mock_creds.token = "token-1"
        mock_creds.expired = False
        mock_creds.project_id = "project-1"
        mock_creds.quota_project_id = "project-1"

        # Test initial credential load and caching
        with patch.object(
            vertex_base, "load_auth", return_value=(mock_creds, "project-1")
        ):
            # First call should load credentials
            if is_async:
                token, project = await vertex_base._ensure_access_token_async(
                    credentials={"type": "service_account"},
                    project_id="project-1",
                    custom_llm_provider="vertex_ai",
                )
            else:
                token, project = vertex_base._ensure_access_token(
                    credentials={"type": "service_account"},
                    project_id="project-1",
                    custom_llm_provider="vertex_ai",
                )
            assert token == "token-1"

            # Second call should use cached credentials
            if is_async:
                token2, project2 = await vertex_base._ensure_access_token_async(
                    credentials={"type": "service_account"},
                    project_id="project-1",
                    custom_llm_provider="vertex_ai",
                )
            else:
                token2, project2 = vertex_base._ensure_access_token(
                    credentials={"type": "service_account"},
                    project_id="project-1",
                    custom_llm_provider="vertex_ai",
                )
            assert token2 == "token-1"
            assert project2 == "project-1"

    @pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
    @pytest.mark.asyncio
    async def test_credential_refresh(self, is_async):
        vertex_base = VertexBase()

        # Create expired credentials
        mock_creds = MagicMock()
        mock_creds.token = "my-token"
        mock_creds.expired = True
        mock_creds.project_id = "project-1"
        mock_creds.quota_project_id = "project-1"

        with patch.object(
            vertex_base, "load_auth", return_value=(mock_creds, "project-1")
        ), patch.object(vertex_base, "refresh_auth") as mock_refresh:

            def mock_refresh_impl(creds):
                creds.token = "refreshed-token"
                creds.expired = False

            mock_refresh.side_effect = mock_refresh_impl

            if is_async:
                token, project = await vertex_base._ensure_access_token_async(
                    credentials={"type": "service_account"},
                    project_id="project-1",
                    custom_llm_provider="vertex_ai",
                )
            else:
                token, project = vertex_base._ensure_access_token(
                    credentials={"type": "service_account"},
                    project_id="project-1",
                    custom_llm_provider="vertex_ai",
                )

            assert mock_refresh.called
            assert token == "refreshed-token"
            assert not mock_creds.expired

    @pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
    @pytest.mark.asyncio
    async def test_gemini_credentials(self, is_async):
        vertex_base = VertexBase()

        # Test that Gemini requests bypass credential checks
        if is_async:
            token, project = await vertex_base._ensure_access_token_async(
                credentials=None, project_id=None, custom_llm_provider="gemini"
            )
        else:
            token, project = vertex_base._ensure_access_token(
                credentials=None, project_id=None, custom_llm_provider="gemini"
            )
        assert token == ""
        assert project == ""
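To run just this new module (assuming pytest and pytest-asyncio are installed, from the repo root), one option is pytest's programmatic entry point:

import pytest

# Programmatic equivalent of invoking pytest on the new test file;
# the path comes from the diff above.
raise SystemExit(
    pytest.main(["-x", "tests/litellm/llms/vertex_ai/test_vertex_llm_base.py"])
)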
@@ -453,6 +453,7 @@ async def test_async_vertexai_response():
             or "ultra" in model
             or "002" in model
             or "gemini-2.0-flash-thinking-exp" in model
+            or "gemini-2.0-pro-exp-02-05" in model
         ):
             # our account does not have access to this model
             continue
@@ -499,6 +500,7 @@ async def test_async_vertexai_streaming_response():
             or "ultra" in model
             or "002" in model
             or "gemini-2.0-flash-thinking-exp" in model
+            or "gemini-2.0-pro-exp-02-05" in model
         ):
             # our account does not have access to this model
             continue