Merge pull request #5622 from BerriAI/litellm_fix_auth_refresh_vertex

[Feat-Perf Improvement Vertex] Only Refresh credentials when token is expired
This commit is contained in:
Ishaan Jaff 2024-09-10 15:03:35 -07:00 committed by GitHub
commit aed48e3bad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -9,7 +9,17 @@ import types
import uuid import uuid
from enum import Enum from enum import Enum
from functools import partial from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
)
import httpx # type: ignore import httpx # type: ignore
import requests # type: ignore import requests # type: ignore
@ -68,6 +78,11 @@ from .transformation import (
sync_transform_request_body, sync_transform_request_body,
) )
if TYPE_CHECKING:
from google.auth.credentials import Credentials as GoogleCredentialsObject
else:
GoogleCredentialsObject = Any
class VertexAIConfig: class VertexAIConfig:
""" """
@ -811,7 +826,7 @@ class VertexLLM(BaseLLM):
super().__init__() super().__init__()
self.access_token: Optional[str] = None self.access_token: Optional[str] = None
self.refresh_token: Optional[str] = None self.refresh_token: Optional[str] = None
self._credentials: Optional[Any] = None self._credentials: Optional[GoogleCredentialsObject] = None
self.project_id: Optional[str] = None self.project_id: Optional[str] = None
self.async_handler: Optional[AsyncHTTPHandler] = None self.async_handler: Optional[AsyncHTTPHandler] = None
@ -1139,10 +1154,11 @@ class VertexLLM(BaseLLM):
if not self.project_id: if not self.project_id:
self.project_id = project_id or cred_project_id self.project_id = project_id or cred_project_id
else: else:
if self._credentials.expired or not self._credentials.token:
self.refresh_auth(self._credentials) self.refresh_auth(self._credentials)
if not self.project_id: if not self.project_id:
self.project_id = self._credentials.project_id self.project_id = self._credentials.quota_project_id
if not self.project_id: if not self.project_id:
raise ValueError("Could not resolve project_id") raise ValueError("Could not resolve project_id")