diff --git a/litellm/integrations/gcs_bucket_base.py b/litellm/integrations/gcs_bucket_base.py
index 073f0f265..15f5cbf68 100644
--- a/litellm/integrations/gcs_bucket_base.py
+++ b/litellm/integrations/gcs_bucket_base.py
@@ -34,10 +34,18 @@ class GCSBucketBase(CustomLogger):
     async def construct_request_headers(self) -> Dict[str, str]:
         from litellm import vertex_chat_completion
 
+        _auth_header, vertex_project = (
+            await vertex_chat_completion._ensure_access_token_async(
+                credentials=self.path_service_account_json,
+                project_id=None,
+            )
+        )
+
         auth_header, _ = vertex_chat_completion._get_token_and_url(
             model="gcs-bucket",
+            auth_header=_auth_header,
             vertex_credentials=self.path_service_account_json,
-            vertex_project=None,
+            vertex_project=vertex_project,
             vertex_location=None,
             gemini_api_key=None,
             stream=None,
@@ -55,10 +63,16 @@ class GCSBucketBase(CustomLogger):
     def sync_construct_request_headers(self) -> Dict[str, str]:
         from litellm import vertex_chat_completion
 
+        _auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
+            credentials=self.path_service_account_json,
+            project_id=None,
+        )
+
         auth_header, _ = vertex_chat_completion._get_token_and_url(
             model="gcs-bucket",
+            auth_header=_auth_header,
             vertex_credentials=self.path_service_account_json,
-            vertex_project=None,
+            vertex_project=vertex_project,
             vertex_location=None,
             gemini_api_key=None,
             stream=None,
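The hunks above show the calling convention this patch introduces everywhere: credential resolution is pulled out of _get_token_and_url into an explicit _ensure_access_token / _ensure_access_token_async step, and the URL builder receives the pre-fetched token through the new auth_header parameter. A minimal sketch of the resulting two-step flow (illustrative only, not code from this PR; the model name and location are placeholders, and VertexBase is the new base class added in vertex_llm_base.py later in this diff):

    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase

    base = VertexBase()

    # Step 1: resolve (and cache) the access token and project id.
    _auth_header, vertex_project = base._ensure_access_token(
        credentials=None,  # None -> fall back to application-default credentials
        project_id=None,
    )

    # Step 2: build the endpoint URL; no credential I/O happens in this call anymore.
    auth_header, url = base._get_token_and_url(
        model="gemini-1.0-pro",  # placeholder model
        auth_header=_auth_header,
        gemini_api_key=None,
        vertex_project=vertex_project,
        vertex_location="us-central1",
        vertex_credentials=None,
        stream=False,
        custom_llm_provider="vertex_ai",
        api_base=None,
    )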
diff --git a/litellm/llms/fine_tuning_apis/vertex_ai.py b/litellm/llms/fine_tuning_apis/vertex_ai.py
index 618cf510a..e24fd3894 100644
--- a/litellm/llms/fine_tuning_apis/vertex_ai.py
+++ b/litellm/llms/fine_tuning_apis/vertex_ai.py
@@ -185,8 +185,14 @@ class VertexFineTuningAPI(VertexLLM):
             "creating fine tuning job, args= %s", create_fine_tuning_job_data
         )
 
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
+
         auth_header, _ = self._get_token_and_url(
             model="",
+            auth_header=_auth_header,
             gemini_api_key=None,
             vertex_credentials=vertex_credentials,
             vertex_project=vertex_project,
@@ -251,8 +257,14 @@ class VertexFineTuningAPI(VertexLLM):
         vertex_credentials: str,
         request_route: str,
     ):
+
+        _auth_header, vertex_project = await self._ensure_access_token_async(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
         auth_header, _ = self._get_token_and_url(
             model="",
+            auth_header=_auth_header,
             gemini_api_key=None,
             vertex_credentials=vertex_credentials,
             vertex_project=vertex_project,
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
index 02907789a..8d69725dd 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
@@ -72,17 +72,13 @@ from ..common_utils import (
     all_gemini_url_modes,
     get_supports_system_message,
 )
+from ..vertex_llm_base import VertexBase
 from .transformation import (
     async_transform_request_body,
     set_headers,
     sync_transform_request_body,
 )
 
-if TYPE_CHECKING:
-    from google.auth.credentials import Credentials as GoogleCredentialsObject
-else:
-    GoogleCredentialsObject = Any
-
 
 class VertexAIConfig:
     """
@@ -821,14 +817,9 @@ def make_sync_call(
     return completion_stream
 
 
-class VertexLLM(BaseLLM):
+class VertexLLM(VertexBase):
     def __init__(self) -> None:
         super().__init__()
-        self.access_token: Optional[str] = None
-        self.refresh_token: Optional[str] = None
-        self._credentials: Optional[GoogleCredentialsObject] = None
-        self.project_id: Optional[str] = None
-        self.async_handler: Optional[AsyncHTTPHandler] = None
 
     def _process_response(
         self,
@@ -1057,201 +1048,13 @@ class VertexLLM(BaseLLM):
 
         return model_response
 
-    def get_vertex_region(self, vertex_region: Optional[str]) -> str:
-        return vertex_region or "us-central1"
-
-    def load_auth(
-        self, credentials: Optional[str], project_id: Optional[str]
-    ) -> Tuple[Any, str]:
-        import google.auth as google_auth
-        from google.auth import identity_pool
-        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
-        from google.auth.transport.requests import (
-            Request,  # type: ignore[import-untyped]
-        )
-
-        if credentials is not None and isinstance(credentials, str):
-            import google.oauth2.service_account
-
-            verbose_logger.debug(
-                "Vertex: Loading vertex credentials from %s", credentials
-            )
-            verbose_logger.debug(
-                "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
-                credentials,
-                os.path.exists(credentials),
-                os.getcwd(),
-            )
-
-            try:
-                if os.path.exists(credentials):
-                    json_obj = json.load(open(credentials))
-                else:
-                    json_obj = json.loads(credentials)
-            except Exception:
-                raise Exception(
-                    "Unable to load vertex credentials from environment. Got={}".format(
-                        credentials
-                    )
-                )
-
-            # Check if the JSON object contains Workload Identity Federation configuration
-            if "type" in json_obj and json_obj["type"] == "external_account":
-                creds = identity_pool.Credentials.from_info(json_obj)
-            else:
-                creds = (
-                    google.oauth2.service_account.Credentials.from_service_account_info(
-                        json_obj,
-                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
-                    )
-                )
-
-            if project_id is None:
-                project_id = creds.project_id
-        else:
-            creds, creds_project_id = google_auth.default(
-                quota_project_id=project_id,
-                scopes=["https://www.googleapis.com/auth/cloud-platform"],
-            )
-            if project_id is None:
-                project_id = creds_project_id
-
-        creds.refresh(Request())
-
-        if not project_id:
-            raise ValueError("Could not resolve project_id")
-
-        if not isinstance(project_id, str):
-            raise TypeError(
-                f"Expected project_id to be a str but got {type(project_id)}"
-            )
-
-        return creds, project_id
-
-    def refresh_auth(self, credentials: Any) -> None:
-        from google.auth.transport.requests import (
-            Request,  # type: ignore[import-untyped]
-        )
-
-        credentials.refresh(Request())
-
-    def _ensure_access_token(
-        self, credentials: Optional[str], project_id: Optional[str]
-    ) -> Tuple[str, str]:
-        """
-        Returns auth token and project id
-        """
-        if self.access_token is not None:
-            if project_id is not None:
-                return self.access_token, project_id
-            elif self.project_id is not None:
-                return self.access_token, self.project_id
-
-        if not self._credentials:
-            self._credentials, cred_project_id = self.load_auth(
-                credentials=credentials, project_id=project_id
-            )
-            if not self.project_id:
-                self.project_id = project_id or cred_project_id
-        else:
-            if self._credentials.expired or not self._credentials.token:
-                self.refresh_auth(self._credentials)
-
-            if not self.project_id:
-                self.project_id = self._credentials.quota_project_id
-
-        if not self.project_id:
-            raise ValueError("Could not resolve project_id")
-
-        if not self._credentials or not self._credentials.token:
-            raise RuntimeError("Could not resolve API token from the environment")
-
-        return self._credentials.token, project_id or self.project_id
-
-    def is_using_v1beta1_features(self, optional_params: dict) -> bool:
-        """
-        VertexAI only supports ContextCaching on v1beta1
-
-        use this helper to decide if request should be sent to v1 or v1beta1
-
-        Returns v1beta1 if context caching is enabled
-        Returns v1 in all other cases
-        """
-        if "cached_content" in optional_params:
-            return True
-        if "CachedContent" in optional_params:
-            return True
-        return False
-
-    def _get_token_and_url(
-        self,
-        model: str,
-        gemini_api_key: Optional[str],
-        vertex_project: Optional[str],
-        vertex_location: Optional[str],
-        vertex_credentials: Optional[str],
-        stream: Optional[bool],
-        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
-        api_base: Optional[str],
-        should_use_v1beta1_features: Optional[bool] = False,
-        mode: all_gemini_url_modes = "chat",
-    ) -> Tuple[Optional[str], str]:
-        """
-        Internal function. Returns the token and url for the call.
-
-        Handles logic if it's google ai studio vs. vertex ai.
-
-        Returns
-            token, url
-        """
-        if custom_llm_provider == "gemini":
-            auth_header = None
-            url, endpoint = _get_gemini_url(
-                mode=mode,
-                model=model,
-                stream=stream,
-                gemini_api_key=gemini_api_key,
-            )
-        else:
-            auth_header, vertex_project = self._ensure_access_token(
-                credentials=vertex_credentials, project_id=vertex_project
-            )
-            vertex_location = self.get_vertex_region(vertex_region=vertex_location)
-
-            ### SET RUNTIME ENDPOINT ###
-            version: Literal["v1beta1", "v1"] = (
-                "v1beta1" if should_use_v1beta1_features is True else "v1"
-            )
-            url, endpoint = _get_vertex_url(
-                mode=mode,
-                model=model,
-                stream=stream,
-                vertex_project=vertex_project,
-                vertex_location=vertex_location,
-                vertex_api_version=version,
-            )
-
-        if (
-            api_base is not None
-        ):  # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
-            if custom_llm_provider == "gemini":
-                url = "{}:{}".format(api_base, endpoint)
-                auth_header = (
-                    gemini_api_key  # cloudflare expects api key as bearer token
-                )
-            else:
-                url = "{}:{}".format(api_base, endpoint)
-
-        if stream is True:
-            url = url + "?alt=sse"
-
-        return auth_header, url
-
     async def async_streaming(
         self,
         model: str,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # if it's vertex_ai or gemini (google ai studio)
         messages: list,
-        api_base: str,
         model_response: ModelResponse,
         print_verbose: Callable,
         data: dict,
@@ -1262,11 +1065,49 @@ class VertexLLM(BaseLLM):
         optional_params: dict,
         litellm_params=None,
         logger_fn=None,
-        headers={},
+        api_base: Optional[str] = None,
         client: Optional[AsyncHTTPHandler] = None,
+        vertex_project: Optional[str] = None,
+        vertex_location: Optional[str] = None,
+        vertex_credentials: Optional[str] = None,
+        extra_headers: Optional[dict] = None,
     ) -> CustomStreamWrapper:
         request_body = await async_transform_request_body(**data)  # type: ignore
 
+        should_use_v1beta1_features = self.is_using_v1beta1_features(
+            optional_params=optional_params
+        )
+
+        _auth_header, vertex_project = await self._ensure_access_token_async(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=None,
+            auth_header=_auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=stream,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=should_use_v1beta1_features,
+        )
+
+        headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key="",
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+
        request_body_str = json.dumps(request_body)
         streaming_response = CustomStreamWrapper(
             completion_stream=None,
@@ -1290,21 +1131,50 @@ class VertexLLM(BaseLLM):
         self,
         model: str,
         messages: list,
-        api_base: str,
         model_response: ModelResponse,
         print_verbose: Callable,
         data: dict,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # if it's vertex_ai or gemini (google ai studio)
         timeout: Optional[Union[float, httpx.Timeout]],
         encoding,
         logging_obj,
         stream,
         optional_params: dict,
         litellm_params: dict,
-        headers: dict,
         logger_fn=None,
+        api_base: Optional[str] = None,
         client: Optional[AsyncHTTPHandler] = None,
+        vertex_project: Optional[str] = None,
+        vertex_location: Optional[str] = None,
+        vertex_credentials: Optional[str] = None,
+        extra_headers: Optional[dict] = None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
+        should_use_v1beta1_features = self.is_using_v1beta1_features(
+            optional_params=optional_params
+        )
+
+        _auth_header, vertex_project = await self._ensure_access_token_async(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=None,
+            auth_header=_auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=stream,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=should_use_v1beta1_features,
+        )
+
+        headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
+
         request_body = await async_transform_request_body(**data)  # type: ignore
 
         _async_client_params = {}
         if timeout:
@@ -1373,22 +1243,6 @@ class VertexLLM(BaseLLM):
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore
 
-        should_use_v1beta1_features = self.is_using_v1beta1_features(
-            optional_params=optional_params
-        )
-
-        auth_header, url = self._get_token_and_url(
-            model=model,
-            gemini_api_key=gemini_api_key,
-            vertex_project=vertex_project,
-            vertex_location=vertex_location,
-            vertex_credentials=vertex_credentials,
-            stream=stream,
-            custom_llm_provider=custom_llm_provider,
-            api_base=api_base,
-            should_use_v1beta1_features=should_use_v1beta1_features,
-        )
-
         transform_request_params = {
             "gemini_api_key": gemini_api_key,
             "messages": messages,
@@ -1403,8 +1257,6 @@ class VertexLLM(BaseLLM):
             "litellm_params": litellm_params,
         }
 
-        headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
-
         ### ROUTING (ASYNC, STREAMING, SYNC)
         if acompletion:
             ### ASYNC STREAMING
api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + request_body_str = json.dumps(request_body) streaming_response = CustomStreamWrapper( completion_stream=None, @@ -1290,21 +1131,50 @@ class VertexLLM(BaseLLM): self, model: str, messages: list, - api_base: str, model_response: ModelResponse, print_verbose: Callable, data: dict, + custom_llm_provider: Literal[ + "vertex_ai", "vertex_ai_beta", "gemini" + ], # if it's vertex_ai or gemini (google ai studio) timeout: Optional[Union[float, httpx.Timeout]], encoding, logging_obj, stream, optional_params: dict, litellm_params: dict, - headers: dict, logger_fn=None, + api_base: Optional[str] = None, client: Optional[AsyncHTTPHandler] = None, + vertex_project: Optional[str] = None, + vertex_location: Optional[str] = None, + vertex_credentials: Optional[str] = None, + extra_headers: Optional[dict] = None, ) -> Union[ModelResponse, CustomStreamWrapper]: + should_use_v1beta1_features = self.is_using_v1beta1_features( + optional_params=optional_params + ) + + _auth_header, vertex_project = await self._ensure_access_token_async( + credentials=vertex_credentials, project_id=vertex_project + ) + + auth_header, api_base = self._get_token_and_url( + model=model, + gemini_api_key=None, + auth_header=_auth_header, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + stream=stream, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + should_use_v1beta1_features=should_use_v1beta1_features, + ) + + headers = set_headers(auth_header=auth_header, extra_headers=extra_headers) + request_body = await async_transform_request_body(**data) # type: ignore _async_client_params = {} if timeout: @@ -1373,22 +1243,6 @@ class VertexLLM(BaseLLM): ) -> Union[ModelResponse, CustomStreamWrapper]: stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore - should_use_v1beta1_features = self.is_using_v1beta1_features( - optional_params=optional_params - ) - - auth_header, url = self._get_token_and_url( - model=model, - gemini_api_key=gemini_api_key, - vertex_project=vertex_project, - vertex_location=vertex_location, - vertex_credentials=vertex_credentials, - stream=stream, - custom_llm_provider=custom_llm_provider, - api_base=api_base, - should_use_v1beta1_features=should_use_v1beta1_features, - ) - transform_request_params = { "gemini_api_key": gemini_api_key, "messages": messages, @@ -1403,8 +1257,6 @@ class VertexLLM(BaseLLM): "litellm_params": litellm_params, } - headers = set_headers(auth_header=auth_header, extra_headers=extra_headers) - ### ROUTING (ASYNC, STREAMING, SYNC) if acompletion: ### ASYNC STREAMING @@ -1412,7 +1264,7 @@ class VertexLLM(BaseLLM): return self.async_streaming( model=model, messages=messages, - api_base=url, + api_base=api_base, model_response=model_response, print_verbose=print_verbose, encoding=encoding, @@ -1424,14 +1276,18 @@ class VertexLLM(BaseLLM): timeout=timeout, client=client, # type: ignore data=transform_request_params, - headers=headers, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + custom_llm_provider=custom_llm_provider, + extra_headers=extra_headers, ) ### ASYNC COMPLETION return self.async_completion( model=model, messages=messages, data=transform_request_params, # type: ignore - api_base=url, + api_base=api_base, model_response=model_response, print_verbose=print_verbose, encoding=encoding, @@ -1442,10 +1298,35 @@ 
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_handler.py
index d05688dee..d5d3cf1ec 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_handler.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_handler.py
@@ -43,8 +43,14 @@ class GoogleBatchEmbeddings(VertexLLM):
         client=None,
     ) -> EmbeddingResponse:
 
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
+
         auth_header, url = self._get_token_and_url(
             model=model,
+            auth_header=_auth_header,
             gemini_api_key=api_key,
             vertex_project=vertex_project,
             vertex_location=vertex_location,
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py
index 455358191..aa8c2123a 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py
@@ -43,8 +43,15 @@ class VertexMultimodalEmbedding(VertexLLM):
         timeout=300,
         client=None,
     ):
+
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
+
         auth_header, url = self._get_token_and_url(
             model=model,
+            auth_header=_auth_header,
             gemini_api_key=api_key,
             vertex_project=vertex_project,
             vertex_location=vertex_location,
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py
index bc2424ecc..8818d13bc 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py
@@ -65,8 +65,15 @@ class VertexTextToSpeechAPI(VertexLLM):
         import base64
 
         ####### Authenticate with Vertex AI ########
+
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
+
         auth_header, _ = self._get_token_and_url(
             model="",
+            auth_header=_auth_header,
             gemini_api_key=None,
             vertex_credentials=vertex_credentials,
             vertex_project=vertex_project,
Got={}".format( + credentials + ) + ) + + # Check if the JSON object contains Workload Identity Federation configuration + if "type" in json_obj and json_obj["type"] == "external_account": + creds = identity_pool.Credentials.from_info(json_obj) + else: + creds = ( + google.oauth2.service_account.Credentials.from_service_account_info( + json_obj, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + ) + + if project_id is None: + project_id = creds.project_id + else: + creds, creds_project_id = google_auth.default( + quota_project_id=project_id, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + if project_id is None: + project_id = creds_project_id + + creds.refresh(Request()) + + if not project_id: + raise ValueError("Could not resolve project_id") + + if not isinstance(project_id, str): + raise TypeError( + f"Expected project_id to be a str but got {type(project_id)}" + ) + + return creds, project_id + + def refresh_auth(self, credentials: Any) -> None: + from google.auth.transport.requests import ( + Request, # type: ignore[import-untyped] + ) + + credentials.refresh(Request()) + + def _ensure_access_token( + self, credentials: Optional[str], project_id: Optional[str] + ) -> Tuple[str, str]: + """ + Returns auth token and project id + """ + if self.access_token is not None: + if project_id is not None: + return self.access_token, project_id + elif self.project_id is not None: + return self.access_token, self.project_id + + if not self._credentials: + self._credentials, cred_project_id = self.load_auth( + credentials=credentials, project_id=project_id + ) + if not self.project_id: + self.project_id = project_id or cred_project_id + else: + if self._credentials.expired or not self._credentials.token: + self.refresh_auth(self._credentials) + + if not self.project_id: + self.project_id = self._credentials.quota_project_id + + if not self.project_id: + raise ValueError("Could not resolve project_id") + + if not self._credentials or not self._credentials.token: + raise RuntimeError("Could not resolve API token from the environment") + + return self._credentials.token, project_id or self.project_id + + def is_using_v1beta1_features(self, optional_params: dict) -> bool: + """ + VertexAI only supports ContextCaching on v1beta1 + + use this helper to decide if request should be sent to v1 or v1beta1 + + Returns v1beta1 if context caching is enabled + Returns v1 in all other cases + """ + if "cached_content" in optional_params: + return True + if "CachedContent" in optional_params: + return True + return False + + def _get_token_and_url( + self, + model: str, + auth_header: str, + gemini_api_key: Optional[str], + vertex_project: Optional[str], + vertex_location: Optional[str], + vertex_credentials: Optional[str], + stream: Optional[bool], + custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"], + api_base: Optional[str], + should_use_v1beta1_features: Optional[bool] = False, + mode: all_gemini_url_modes = "chat", + ) -> Tuple[Optional[str], str]: + """ + Internal function. Returns the token and url for the call. + + Handles logic if it's google ai studio vs. vertex ai. 
diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py
index 4e33ee497..68d4146f2 100644
--- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py
+++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py
@@ -150,8 +150,15 @@ async def vertex_proxy_route(
     base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
 
+    _auth_header, vertex_project = (
+        await vertex_fine_tuning_apis_instance._ensure_access_token_async(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+    )
+
     auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
         model="",
+        auth_header=_auth_header,
         gemini_api_key=None,
         vertex_credentials=vertex_credentials,
         vertex_project=vertex_project,
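One detail of the relocated helpers worth calling out: is_using_v1beta1_features decides whether _get_token_and_url targets the v1 or v1beta1 Vertex API surface, since Vertex AI only exposes context caching on v1beta1. The helper is a pure dict check, so this runs without credentials (the cached_content value below is a placeholder):

    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase

    base = VertexBase()

    # Context caching forces the v1beta1 API surface...
    assert base.is_using_v1beta1_features(
        {"cached_content": "projects/p/locations/l/cachedContents/123"}
    ) is True

    # ...while ordinary requests stay on v1.
    assert base.is_using_v1beta1_features({"temperature": 0.2}) is False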