Merge pull request #5622 from BerriAI/litellm_fix_auth_refresh_vertex

[Feat-Perf Improvement Vertex] Only Refresh credentials when token is expired
This commit is contained in:
Ishaan Jaff 2024-09-10 15:03:35 -07:00 committed by GitHub
commit aed48e3bad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -9,7 +9,17 @@ import types
import uuid
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
)
import httpx # type: ignore
import requests # type: ignore
@ -68,6 +78,11 @@ from .transformation import (
sync_transform_request_body,
)
if TYPE_CHECKING:
from google.auth.credentials import Credentials as GoogleCredentialsObject
else:
GoogleCredentialsObject = Any
class VertexAIConfig:
"""
@ -811,7 +826,7 @@ class VertexLLM(BaseLLM):
super().__init__()
self.access_token: Optional[str] = None
self.refresh_token: Optional[str] = None
self._credentials: Optional[Any] = None
self._credentials: Optional[GoogleCredentialsObject] = None
self.project_id: Optional[str] = None
self.async_handler: Optional[AsyncHTTPHandler] = None
@ -1139,10 +1154,11 @@ class VertexLLM(BaseLLM):
if not self.project_id:
self.project_id = project_id or cred_project_id
else:
if self._credentials.expired or not self._credentials.token:
self.refresh_auth(self._credentials)
if not self.project_id:
self.project_id = self._credentials.project_id
self.project_id = self._credentials.quota_project_id
if not self.project_id:
raise ValueError("Could not resolve project_id")