This commit is contained in:
Yingchun Lai 2025-04-24 00:55:37 -07:00 committed by GitHub
commit c89dae22ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes
@ -24,15 +23,10 @@ else:
class VertexBase:
def __init__(self) -> None:
# Set up per-instance authentication state; every attribute starts out empty.
# NOTE(review): this listing is a diff rendering without +/- markers, so some
# of the attribute lines below may be ones the commit removes -- confirm
# against the actual file before relying on any of them.
super().__init__()
self.access_token: Optional[str] = None
self.refresh_token: Optional[str] = None
self._credentials: Optional[GoogleCredentialsObject] = None
# Cache of loaded credential objects, keyed by a
# (credentials, project_id) tuple so repeated lookups can reuse them.
self._credentials_project_mapping: Dict[
Tuple[Optional[VERTEX_CREDENTIALS_TYPES], Optional[str]],
GoogleCredentialsObject,
] = {}
self.project_id: Optional[str] = None
self.async_handler: Optional[AsyncHTTPHandler] = None
def get_vertex_region(self, vertex_region: Optional[str]) -> str:
    """Return the region to use, defaulting when none is supplied.

    Args:
        vertex_region: Region requested by the caller; may be ``None`` or
            an empty string.

    Returns:
        ``vertex_region`` when it is truthy, otherwise ``"us-central1"``.
    """
    if vertex_region:
        return vertex_region
    # Any falsy value (None or "") falls back to the default region.
    return "us-central1"
@ -261,7 +255,7 @@ class VertexBase:
json.dumps(credentials) if isinstance(credentials, dict) else credentials
)
credential_cache_key = (cache_credentials, project_id)
_credentials: Optional[GoogleCredentialsObject] = None
cached_credentials: Optional[GoogleCredentialsObject] = None
verbose_logger.debug(
f"Checking cached credentials for project_id: {project_id}"
@ -271,10 +265,10 @@ class VertexBase:
verbose_logger.debug(
f"Cached credentials found for project_id: {project_id}."
)
_credentials = self._credentials_project_mapping[credential_cache_key]
cached_credentials = self._credentials_project_mapping[credential_cache_key]
verbose_logger.debug("Using cached credentials")
credential_project_id = _credentials.quota_project_id or getattr(
_credentials, "project_id", None
credential_project_id = cached_credentials.quota_project_id or getattr(
cached_credentials, "project_id", None
)
else:
@ -283,7 +277,7 @@ class VertexBase:
)
try:
_credentials, credential_project_id = self.load_auth(
cached_credentials, credential_project_id = self.load_auth(
credentials=credentials, project_id=project_id
)
except Exception as e:
@ -292,14 +286,14 @@ class VertexBase:
)
raise e
if _credentials is None:
if cached_credentials is None:
raise ValueError(
"Could not resolve credentials - either dynamically or from environment, for project_id: {}".format(
project_id
)
)
self._credentials_project_mapping[credential_cache_key] = _credentials
self._credentials_project_mapping[credential_cache_key] = cached_credentials
## VALIDATE CREDENTIALS
verbose_logger.debug(f"Validating credentials for project_id: {project_id}")
@ -310,7 +304,7 @@ class VertexBase:
):
raise ValueError(
"Could not resolve project_id. Credential project_id: {} does not match requested project_id: {}".format(
_credentials.quota_project_id, project_id
cached_credentials.quota_project_id, project_id
)
)
elif (
@ -320,21 +314,21 @@ class VertexBase:
):
project_id = credential_project_id
if _credentials.expired:
self.refresh_auth(_credentials)
if cached_credentials.expired:
self.refresh_auth(cached_credentials)
## VALIDATION STEP
if _credentials.token is None or not isinstance(_credentials.token, str):
if cached_credentials.token is None or not isinstance(cached_credentials.token, str):
raise ValueError(
"Could not resolve credentials token. Got None or non-string token - {}".format(
_credentials.token
cached_credentials.token
)
)
if project_id is None:
raise ValueError("Could not resolve project_id")
return _credentials.token, project_id
return cached_credentials.token, project_id
async def _ensure_access_token_async(
self,