diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py index 9b57c293e..bae023ab5 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py @@ -9,7 +9,17 @@ import types import uuid from enum import Enum from functools import partial -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + Union, +) import httpx # type: ignore import requests # type: ignore @@ -68,6 +78,11 @@ from .transformation import ( sync_transform_request_body, ) +if TYPE_CHECKING: + from google.auth.credentials import Credentials as GoogleCredentialsObject +else: + GoogleCredentialsObject = Any + class VertexAIConfig: """ @@ -811,7 +826,7 @@ class VertexLLM(BaseLLM): super().__init__() self.access_token: Optional[str] = None self.refresh_token: Optional[str] = None - self._credentials: Optional[Any] = None + self._credentials: Optional[GoogleCredentialsObject] = None self.project_id: Optional[str] = None self.async_handler: Optional[AsyncHTTPHandler] = None @@ -1139,10 +1154,11 @@ class VertexLLM(BaseLLM): if not self.project_id: self.project_id = project_id or cred_project_id else: - self.refresh_auth(self._credentials) + if self._credentials.expired or not self._credentials.token: + self.refresh_auth(self._credentials) if not self.project_id: - self.project_id = self._credentials.project_id + self.project_id = self._credentials.quota_project_id if not self.project_id: raise ValueError("Could not resolve project_id")