This commit is contained in:
Yingchun Lai 2025-04-24 00:55:37 -07:00 committed by GitHub
commit c89dae22ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.litellm_core_utils.asyncify import asyncify from litellm.litellm_core_utils.asyncify import asyncify
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes
@ -24,15 +23,10 @@ else:
class VertexBase: class VertexBase:
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.access_token: Optional[str] = None
self.refresh_token: Optional[str] = None
self._credentials: Optional[GoogleCredentialsObject] = None
self._credentials_project_mapping: Dict[ self._credentials_project_mapping: Dict[
Tuple[Optional[VERTEX_CREDENTIALS_TYPES], Optional[str]], Tuple[Optional[VERTEX_CREDENTIALS_TYPES], Optional[str]],
GoogleCredentialsObject, GoogleCredentialsObject,
] = {} ] = {}
self.project_id: Optional[str] = None
self.async_handler: Optional[AsyncHTTPHandler] = None
def get_vertex_region(self, vertex_region: Optional[str]) -> str: def get_vertex_region(self, vertex_region: Optional[str]) -> str:
return vertex_region or "us-central1" return vertex_region or "us-central1"
@ -261,7 +255,7 @@ class VertexBase:
json.dumps(credentials) if isinstance(credentials, dict) else credentials json.dumps(credentials) if isinstance(credentials, dict) else credentials
) )
credential_cache_key = (cache_credentials, project_id) credential_cache_key = (cache_credentials, project_id)
_credentials: Optional[GoogleCredentialsObject] = None cached_credentials: Optional[GoogleCredentialsObject] = None
verbose_logger.debug( verbose_logger.debug(
f"Checking cached credentials for project_id: {project_id}" f"Checking cached credentials for project_id: {project_id}"
@ -271,10 +265,10 @@ class VertexBase:
verbose_logger.debug( verbose_logger.debug(
f"Cached credentials found for project_id: {project_id}." f"Cached credentials found for project_id: {project_id}."
) )
_credentials = self._credentials_project_mapping[credential_cache_key] cached_credentials = self._credentials_project_mapping[credential_cache_key]
verbose_logger.debug("Using cached credentials") verbose_logger.debug("Using cached credentials")
credential_project_id = _credentials.quota_project_id or getattr( credential_project_id = cached_credentials.quota_project_id or getattr(
_credentials, "project_id", None cached_credentials, "project_id", None
) )
else: else:
@ -283,7 +277,7 @@ class VertexBase:
) )
try: try:
_credentials, credential_project_id = self.load_auth( cached_credentials, credential_project_id = self.load_auth(
credentials=credentials, project_id=project_id credentials=credentials, project_id=project_id
) )
except Exception as e: except Exception as e:
@ -292,14 +286,14 @@ class VertexBase:
) )
raise e raise e
if _credentials is None: if cached_credentials is None:
raise ValueError( raise ValueError(
"Could not resolve credentials - either dynamically or from environment, for project_id: {}".format( "Could not resolve credentials - either dynamically or from environment, for project_id: {}".format(
project_id project_id
) )
) )
self._credentials_project_mapping[credential_cache_key] = _credentials self._credentials_project_mapping[credential_cache_key] = cached_credentials
## VALIDATE CREDENTIALS ## VALIDATE CREDENTIALS
verbose_logger.debug(f"Validating credentials for project_id: {project_id}") verbose_logger.debug(f"Validating credentials for project_id: {project_id}")
@ -310,7 +304,7 @@ class VertexBase:
): ):
raise ValueError( raise ValueError(
"Could not resolve project_id. Credential project_id: {} does not match requested project_id: {}".format( "Could not resolve project_id. Credential project_id: {} does not match requested project_id: {}".format(
_credentials.quota_project_id, project_id cached_credentials.quota_project_id, project_id
) )
) )
elif ( elif (
@ -320,21 +314,21 @@ class VertexBase:
): ):
project_id = credential_project_id project_id = credential_project_id
if _credentials.expired: if cached_credentials.expired:
self.refresh_auth(_credentials) self.refresh_auth(cached_credentials)
## VALIDATION STEP ## VALIDATION STEP
if _credentials.token is None or not isinstance(_credentials.token, str): if cached_credentials.token is None or not isinstance(cached_credentials.token, str):
raise ValueError( raise ValueError(
"Could not resolve credentials token. Got None or non-string token - {}".format( "Could not resolve credentials token. Got None or non-string token - {}".format(
_credentials.token cached_credentials.token
) )
) )
if project_id is None: if project_id is None:
raise ValueError("Could not resolve project_id") raise ValueError("Could not resolve project_id")
return _credentials.token, project_id return cached_credentials.token, project_id
async def _ensure_access_token_async( async def _ensure_access_token_async(
self, self,