mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Merge 0b70fa72af
into b82af5b826
This commit is contained in:
commit
c89dae22ec
1 changed files with 13 additions and 19 deletions
|
@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
|
|||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.asyncify import asyncify
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||
|
||||
from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes
|
||||
|
@ -24,15 +23,10 @@ else:
|
|||
class VertexBase:
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.access_token: Optional[str] = None
|
||||
self.refresh_token: Optional[str] = None
|
||||
self._credentials: Optional[GoogleCredentialsObject] = None
|
||||
self._credentials_project_mapping: Dict[
|
||||
Tuple[Optional[VERTEX_CREDENTIALS_TYPES], Optional[str]],
|
||||
GoogleCredentialsObject,
|
||||
] = {}
|
||||
self.project_id: Optional[str] = None
|
||||
self.async_handler: Optional[AsyncHTTPHandler] = None
|
||||
|
||||
def get_vertex_region(self, vertex_region: Optional[str]) -> str:
|
||||
return vertex_region or "us-central1"
|
||||
|
@ -261,7 +255,7 @@ class VertexBase:
|
|||
json.dumps(credentials) if isinstance(credentials, dict) else credentials
|
||||
)
|
||||
credential_cache_key = (cache_credentials, project_id)
|
||||
_credentials: Optional[GoogleCredentialsObject] = None
|
||||
cached_credentials: Optional[GoogleCredentialsObject] = None
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Checking cached credentials for project_id: {project_id}"
|
||||
|
@ -271,10 +265,10 @@ class VertexBase:
|
|||
verbose_logger.debug(
|
||||
f"Cached credentials found for project_id: {project_id}."
|
||||
)
|
||||
_credentials = self._credentials_project_mapping[credential_cache_key]
|
||||
cached_credentials = self._credentials_project_mapping[credential_cache_key]
|
||||
verbose_logger.debug("Using cached credentials")
|
||||
credential_project_id = _credentials.quota_project_id or getattr(
|
||||
_credentials, "project_id", None
|
||||
credential_project_id = cached_credentials.quota_project_id or getattr(
|
||||
cached_credentials, "project_id", None
|
||||
)
|
||||
|
||||
else:
|
||||
|
@ -283,7 +277,7 @@ class VertexBase:
|
|||
)
|
||||
|
||||
try:
|
||||
_credentials, credential_project_id = self.load_auth(
|
||||
cached_credentials, credential_project_id = self.load_auth(
|
||||
credentials=credentials, project_id=project_id
|
||||
)
|
||||
except Exception as e:
|
||||
|
@ -292,14 +286,14 @@ class VertexBase:
|
|||
)
|
||||
raise e
|
||||
|
||||
if _credentials is None:
|
||||
if cached_credentials is None:
|
||||
raise ValueError(
|
||||
"Could not resolve credentials - either dynamically or from environment, for project_id: {}".format(
|
||||
project_id
|
||||
)
|
||||
)
|
||||
|
||||
self._credentials_project_mapping[credential_cache_key] = _credentials
|
||||
self._credentials_project_mapping[credential_cache_key] = cached_credentials
|
||||
|
||||
## VALIDATE CREDENTIALS
|
||||
verbose_logger.debug(f"Validating credentials for project_id: {project_id}")
|
||||
|
@ -310,7 +304,7 @@ class VertexBase:
|
|||
):
|
||||
raise ValueError(
|
||||
"Could not resolve project_id. Credential project_id: {} does not match requested project_id: {}".format(
|
||||
_credentials.quota_project_id, project_id
|
||||
cached_credentials.quota_project_id, project_id
|
||||
)
|
||||
)
|
||||
elif (
|
||||
|
@ -320,21 +314,21 @@ class VertexBase:
|
|||
):
|
||||
project_id = credential_project_id
|
||||
|
||||
if _credentials.expired:
|
||||
self.refresh_auth(_credentials)
|
||||
if cached_credentials.expired:
|
||||
self.refresh_auth(cached_credentials)
|
||||
|
||||
## VALIDATION STEP
|
||||
if _credentials.token is None or not isinstance(_credentials.token, str):
|
||||
if cached_credentials.token is None or not isinstance(cached_credentials.token, str):
|
||||
raise ValueError(
|
||||
"Could not resolve credentials token. Got None or non-string token - {}".format(
|
||||
_credentials.token
|
||||
cached_credentials.token
|
||||
)
|
||||
)
|
||||
|
||||
if project_id is None:
|
||||
raise ValueError("Could not resolve project_id")
|
||||
|
||||
return _credentials.token, project_id
|
||||
return cached_credentials.token, project_id
|
||||
|
||||
async def _ensure_access_token_async(
|
||||
self,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue