mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Merge 0b70fa72af
into b82af5b826
This commit is contained in:
commit
c89dae22ec
1 changed files with 13 additions and 19 deletions
|
@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
|
||||||
|
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
from litellm.litellm_core_utils.asyncify import asyncify
|
from litellm.litellm_core_utils.asyncify import asyncify
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
|
||||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||||
|
|
||||||
from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes
|
from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes
|
||||||
|
@ -24,15 +23,10 @@ else:
|
||||||
class VertexBase:
|
class VertexBase:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.access_token: Optional[str] = None
|
|
||||||
self.refresh_token: Optional[str] = None
|
|
||||||
self._credentials: Optional[GoogleCredentialsObject] = None
|
|
||||||
self._credentials_project_mapping: Dict[
|
self._credentials_project_mapping: Dict[
|
||||||
Tuple[Optional[VERTEX_CREDENTIALS_TYPES], Optional[str]],
|
Tuple[Optional[VERTEX_CREDENTIALS_TYPES], Optional[str]],
|
||||||
GoogleCredentialsObject,
|
GoogleCredentialsObject,
|
||||||
] = {}
|
] = {}
|
||||||
self.project_id: Optional[str] = None
|
|
||||||
self.async_handler: Optional[AsyncHTTPHandler] = None
|
|
||||||
|
|
||||||
def get_vertex_region(self, vertex_region: Optional[str]) -> str:
|
def get_vertex_region(self, vertex_region: Optional[str]) -> str:
|
||||||
return vertex_region or "us-central1"
|
return vertex_region or "us-central1"
|
||||||
|
@ -261,7 +255,7 @@ class VertexBase:
|
||||||
json.dumps(credentials) if isinstance(credentials, dict) else credentials
|
json.dumps(credentials) if isinstance(credentials, dict) else credentials
|
||||||
)
|
)
|
||||||
credential_cache_key = (cache_credentials, project_id)
|
credential_cache_key = (cache_credentials, project_id)
|
||||||
_credentials: Optional[GoogleCredentialsObject] = None
|
cached_credentials: Optional[GoogleCredentialsObject] = None
|
||||||
|
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
f"Checking cached credentials for project_id: {project_id}"
|
f"Checking cached credentials for project_id: {project_id}"
|
||||||
|
@ -271,10 +265,10 @@ class VertexBase:
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
f"Cached credentials found for project_id: {project_id}."
|
f"Cached credentials found for project_id: {project_id}."
|
||||||
)
|
)
|
||||||
_credentials = self._credentials_project_mapping[credential_cache_key]
|
cached_credentials = self._credentials_project_mapping[credential_cache_key]
|
||||||
verbose_logger.debug("Using cached credentials")
|
verbose_logger.debug("Using cached credentials")
|
||||||
credential_project_id = _credentials.quota_project_id or getattr(
|
credential_project_id = cached_credentials.quota_project_id or getattr(
|
||||||
_credentials, "project_id", None
|
cached_credentials, "project_id", None
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -283,7 +277,7 @@ class VertexBase:
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_credentials, credential_project_id = self.load_auth(
|
cached_credentials, credential_project_id = self.load_auth(
|
||||||
credentials=credentials, project_id=project_id
|
credentials=credentials, project_id=project_id
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -292,14 +286,14 @@ class VertexBase:
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if _credentials is None:
|
if cached_credentials is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Could not resolve credentials - either dynamically or from environment, for project_id: {}".format(
|
"Could not resolve credentials - either dynamically or from environment, for project_id: {}".format(
|
||||||
project_id
|
project_id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._credentials_project_mapping[credential_cache_key] = _credentials
|
self._credentials_project_mapping[credential_cache_key] = cached_credentials
|
||||||
|
|
||||||
## VALIDATE CREDENTIALS
|
## VALIDATE CREDENTIALS
|
||||||
verbose_logger.debug(f"Validating credentials for project_id: {project_id}")
|
verbose_logger.debug(f"Validating credentials for project_id: {project_id}")
|
||||||
|
@ -310,7 +304,7 @@ class VertexBase:
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Could not resolve project_id. Credential project_id: {} does not match requested project_id: {}".format(
|
"Could not resolve project_id. Credential project_id: {} does not match requested project_id: {}".format(
|
||||||
_credentials.quota_project_id, project_id
|
cached_credentials.quota_project_id, project_id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
|
@ -320,21 +314,21 @@ class VertexBase:
|
||||||
):
|
):
|
||||||
project_id = credential_project_id
|
project_id = credential_project_id
|
||||||
|
|
||||||
if _credentials.expired:
|
if cached_credentials.expired:
|
||||||
self.refresh_auth(_credentials)
|
self.refresh_auth(cached_credentials)
|
||||||
|
|
||||||
## VALIDATION STEP
|
## VALIDATION STEP
|
||||||
if _credentials.token is None or not isinstance(_credentials.token, str):
|
if cached_credentials.token is None or not isinstance(cached_credentials.token, str):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Could not resolve credentials token. Got None or non-string token - {}".format(
|
"Could not resolve credentials token. Got None or non-string token - {}".format(
|
||||||
_credentials.token
|
cached_credentials.token
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if project_id is None:
|
if project_id is None:
|
||||||
raise ValueError("Could not resolve project_id")
|
raise ValueError("Could not resolve project_id")
|
||||||
|
|
||||||
return _credentials.token, project_id
|
return cached_credentials.token, project_id
|
||||||
|
|
||||||
async def _ensure_access_token_async(
|
async def _ensure_access_token_async(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue