forked from phoenix/litellm-mirror

fix vertex use async func to set auth creds

parent 26ae86e59b
commit 1c6f8b1be2
8 changed files with 420 additions and 230 deletions
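The same pattern repeats across all eight files: token resolution moves out of _get_token_and_url into an explicit _ensure_access_token (sync) or _ensure_access_token_async (async) call, and the resolved token is then handed to _get_token_and_url through its new auth_header parameter. A minimal before/after sketch of the async call shape, with names taken from the diff below (this is a paraphrase of the change, not a verbatim excerpt):

# Before: _get_token_and_url resolved credentials internally, doing
# blocking google-auth I/O even on async call paths.
auth_header, url = self._get_token_and_url(
    model=model,
    gemini_api_key=None,
    vertex_project=vertex_project,
    vertex_location=vertex_location,
    vertex_credentials=vertex_credentials,
    stream=stream,
    custom_llm_provider="vertex_ai",
    api_base=None,
)

# After: resolve the token first (awaiting the async helper on async
# paths), then pass it in explicitly.
_auth_header, vertex_project = await self._ensure_access_token_async(
    credentials=vertex_credentials, project_id=vertex_project
)
auth_header, url = self._get_token_and_url(
    model=model,
    gemini_api_key=None,
    auth_header=_auth_header,
    vertex_project=vertex_project,
    vertex_location=vertex_location,
    vertex_credentials=vertex_credentials,
    stream=stream,
    custom_llm_provider="vertex_ai",
    api_base=None,
)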
@@ -34,10 +34,18 @@ class GCSBucketBase(CustomLogger):
     async def construct_request_headers(self) -> Dict[str, str]:
         from litellm import vertex_chat_completion

+        _auth_header, vertex_project = (
+            await vertex_chat_completion._ensure_access_token_async(
+                credentials=self.path_service_account_json,
+                project_id=None,
+            )
+        )
+
         auth_header, _ = vertex_chat_completion._get_token_and_url(
             model="gcs-bucket",
+            auth_header=_auth_header,
             vertex_credentials=self.path_service_account_json,
-            vertex_project=None,
+            vertex_project=vertex_project,
             vertex_location=None,
             gemini_api_key=None,
             stream=None,

@@ -55,10 +63,16 @@ class GCSBucketBase(CustomLogger):
     def sync_construct_request_headers(self) -> Dict[str, str]:
         from litellm import vertex_chat_completion

+        _auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
+            credentials=self.path_service_account_json,
+            project_id=None,
+        )
+
         auth_header, _ = vertex_chat_completion._get_token_and_url(
             model="gcs-bucket",
+            auth_header=_auth_header,
             vertex_credentials=self.path_service_account_json,
-            vertex_project=None,
+            vertex_project=vertex_project,
             vertex_location=None,
             gemini_api_key=None,
             stream=None,
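With the hunks above, GCS log upload can resolve credentials without stalling the event loop. A usage sketch, assuming a GCSBucketBase instance whose path_service_account_json is already configured (its construction is outside this diff):

async def emit_log(gcs_logger: GCSBucketBase) -> None:
    # Async path: the credential load/refresh no longer blocks the event loop.
    headers = await gcs_logger.construct_request_headers()

    # Synchronous callers keep the blocking variant:
    # headers = gcs_logger.sync_construct_request_headers()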
@@ -185,8 +185,14 @@ class VertexFineTuningAPI(VertexLLM):
             "creating fine tuning job, args= %s", create_fine_tuning_job_data
         )

+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
+
         auth_header, _ = self._get_token_and_url(
             model="",
+            auth_header=_auth_header,
             gemini_api_key=None,
             vertex_credentials=vertex_credentials,
             vertex_project=vertex_project,

@@ -251,8 +257,14 @@ class VertexFineTuningAPI(VertexLLM):
         vertex_credentials: str,
         request_route: str,
     ):
+
+        _auth_header, vertex_project = await self._ensure_access_token_async(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
         auth_header, _ = self._get_token_and_url(
             model="",
+            auth_header=_auth_header,
             gemini_api_key=None,
             vertex_credentials=vertex_credentials,
             vertex_project=vertex_project,
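Note that vertex_credentials here is a plain str. Per load_auth (reproduced in full in vertex_llm_base.py below), that string may be either a path to a credentials JSON file or the JSON content itself, and keys with "type": "external_account" (Workload Identity Federation) are routed to identity_pool.Credentials instead of a service account. A sketch of the two accepted shapes; the api instance and the env var name are hypothetical:

import os

# Either a filesystem path to a service-account / WIF JSON file...
creds = "/secrets/vertex-sa.json"
# ...or the JSON content itself, e.g. pulled from an env var (hypothetical name):
creds = os.environ.get("VERTEX_SERVICE_ACCOUNT_JSON", creds)

_auth_header, project = api._ensure_access_token(  # api: e.g. a VertexFineTuningAPI()
    credentials=creds, project_id=None
)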
@@ -72,17 +72,13 @@ from ..common_utils import (
     all_gemini_url_modes,
     get_supports_system_message,
 )
+from ..vertex_llm_base import VertexBase
 from .transformation import (
     async_transform_request_body,
     set_headers,
     sync_transform_request_body,
 )

-if TYPE_CHECKING:
-    from google.auth.credentials import Credentials as GoogleCredentialsObject
-else:
-    GoogleCredentialsObject = Any
-

 class VertexAIConfig:
     """

@@ -821,14 +817,9 @@ def make_sync_call(
     return completion_stream


-class VertexLLM(BaseLLM):
+class VertexLLM(VertexBase):
     def __init__(self) -> None:
         super().__init__()
-        self.access_token: Optional[str] = None
-        self.refresh_token: Optional[str] = None
-        self._credentials: Optional[GoogleCredentialsObject] = None
-        self.project_id: Optional[str] = None
-        self.async_handler: Optional[AsyncHTTPHandler] = None

     def _process_response(
         self,

@@ -1057,201 +1048,13 @@ class VertexLLM(BaseLLM):

         return model_response

-    def get_vertex_region(self, vertex_region: Optional[str]) -> str:
-        return vertex_region or "us-central1"
-
-    def load_auth(
-        self, credentials: Optional[str], project_id: Optional[str]
-    ) -> Tuple[Any, str]:
-        import google.auth as google_auth
-        from google.auth import identity_pool
-        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
-        from google.auth.transport.requests import (
-            Request,  # type: ignore[import-untyped]
-        )
-
-        if credentials is not None and isinstance(credentials, str):
-            import google.oauth2.service_account
-
-            verbose_logger.debug(
-                "Vertex: Loading vertex credentials from %s", credentials
-            )
-            verbose_logger.debug(
-                "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
-                credentials,
-                os.path.exists(credentials),
-                os.getcwd(),
-            )
-
-            try:
-                if os.path.exists(credentials):
-                    json_obj = json.load(open(credentials))
-                else:
-                    json_obj = json.loads(credentials)
-            except Exception:
-                raise Exception(
-                    "Unable to load vertex credentials from environment. Got={}".format(
-                        credentials
-                    )
-                )
-
-            # Check if the JSON object contains Workload Identity Federation configuration
-            if "type" in json_obj and json_obj["type"] == "external_account":
-                creds = identity_pool.Credentials.from_info(json_obj)
-            else:
-                creds = (
-                    google.oauth2.service_account.Credentials.from_service_account_info(
-                        json_obj,
-                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
-                    )
-                )
-
-            if project_id is None:
-                project_id = creds.project_id
-        else:
-            creds, creds_project_id = google_auth.default(
-                quota_project_id=project_id,
-                scopes=["https://www.googleapis.com/auth/cloud-platform"],
-            )
-            if project_id is None:
-                project_id = creds_project_id
-
-        creds.refresh(Request())
-
-        if not project_id:
-            raise ValueError("Could not resolve project_id")
-
-        if not isinstance(project_id, str):
-            raise TypeError(
-                f"Expected project_id to be a str but got {type(project_id)}"
-            )
-
-        return creds, project_id
-
-    def refresh_auth(self, credentials: Any) -> None:
-        from google.auth.transport.requests import (
-            Request,  # type: ignore[import-untyped]
-        )
-
-        credentials.refresh(Request())
-
-    def _ensure_access_token(
-        self, credentials: Optional[str], project_id: Optional[str]
-    ) -> Tuple[str, str]:
-        """
-        Returns auth token and project id
-        """
-        if self.access_token is not None:
-            if project_id is not None:
-                return self.access_token, project_id
-            elif self.project_id is not None:
-                return self.access_token, self.project_id
-
-        if not self._credentials:
-            self._credentials, cred_project_id = self.load_auth(
-                credentials=credentials, project_id=project_id
-            )
-            if not self.project_id:
-                self.project_id = project_id or cred_project_id
-        else:
-            if self._credentials.expired or not self._credentials.token:
-                self.refresh_auth(self._credentials)
-
-            if not self.project_id:
-                self.project_id = self._credentials.quota_project_id
-
-        if not self.project_id:
-            raise ValueError("Could not resolve project_id")
-
-        if not self._credentials or not self._credentials.token:
-            raise RuntimeError("Could not resolve API token from the environment")
-
-        return self._credentials.token, project_id or self.project_id
-
-    def is_using_v1beta1_features(self, optional_params: dict) -> bool:
-        """
-        VertexAI only supports ContextCaching on v1beta1
-
-        use this helper to decide if request should be sent to v1 or v1beta1
-
-        Returns v1beta1 if context caching is enabled
-        Returns v1 in all other cases
-        """
-        if "cached_content" in optional_params:
-            return True
-        if "CachedContent" in optional_params:
-            return True
-        return False
-
-    def _get_token_and_url(
-        self,
-        model: str,
-        gemini_api_key: Optional[str],
-        vertex_project: Optional[str],
-        vertex_location: Optional[str],
-        vertex_credentials: Optional[str],
-        stream: Optional[bool],
-        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
-        api_base: Optional[str],
-        should_use_v1beta1_features: Optional[bool] = False,
-        mode: all_gemini_url_modes = "chat",
-    ) -> Tuple[Optional[str], str]:
-        """
-        Internal function. Returns the token and url for the call.
-
-        Handles logic if it's google ai studio vs. vertex ai.
-
-        Returns
-            token, url
-        """
-        if custom_llm_provider == "gemini":
-            auth_header = None
-            url, endpoint = _get_gemini_url(
-                mode=mode,
-                model=model,
-                stream=stream,
-                gemini_api_key=gemini_api_key,
-            )
-        else:
-            auth_header, vertex_project = self._ensure_access_token(
-                credentials=vertex_credentials, project_id=vertex_project
-            )
-            vertex_location = self.get_vertex_region(vertex_region=vertex_location)
-
-            ### SET RUNTIME ENDPOINT ###
-            version: Literal["v1beta1", "v1"] = (
-                "v1beta1" if should_use_v1beta1_features is True else "v1"
-            )
-            url, endpoint = _get_vertex_url(
-                mode=mode,
-                model=model,
-                stream=stream,
-                vertex_project=vertex_project,
-                vertex_location=vertex_location,
-                vertex_api_version=version,
-            )
-
-        if (
-            api_base is not None
-        ):  # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
-            if custom_llm_provider == "gemini":
-                url = "{}:{}".format(api_base, endpoint)
-                auth_header = (
-                    gemini_api_key  # cloudflare expects api key as bearer token
-                )
-            else:
-                url = "{}:{}".format(api_base, endpoint)
-
-        if stream is True:
-            url = url + "?alt=sse"
-
-        return auth_header, url

     async def async_streaming(
         self,
         model: str,
         custom_llm_provider: Literal[
             "vertex_ai", "vertex_ai_beta", "gemini"
         ],  # if it's vertex_ai or gemini (google ai studio)
         messages: list,
-        api_base: str,
         model_response: ModelResponse,
         print_verbose: Callable,
         data: dict,

@@ -1262,11 +1065,49 @@ class VertexLLM(BaseLLM):
         optional_params: dict,
         litellm_params=None,
         logger_fn=None,
         headers={},
+        api_base: Optional[str] = None,
+        client: Optional[AsyncHTTPHandler] = None,
+        vertex_project: Optional[str] = None,
+        vertex_location: Optional[str] = None,
+        vertex_credentials: Optional[str] = None,
+        extra_headers: Optional[dict] = None,
     ) -> CustomStreamWrapper:
+        request_body = await async_transform_request_body(**data)  # type: ignore
+
+        should_use_v1beta1_features = self.is_using_v1beta1_features(
+            optional_params=optional_params
+        )
+
+        _auth_header, vertex_project = await self._ensure_access_token_async(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=None,
+            auth_header=_auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=stream,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=should_use_v1beta1_features,
+        )
+
+        headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key="",
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )

         request_body_str = json.dumps(request_body)
         streaming_response = CustomStreamWrapper(
             completion_stream=None,

@@ -1290,21 +1131,50 @@ class VertexLLM(BaseLLM):
         self,
         model: str,
         messages: list,
-        api_base: str,
         model_response: ModelResponse,
         print_verbose: Callable,
         data: dict,
         custom_llm_provider: Literal[
             "vertex_ai", "vertex_ai_beta", "gemini"
         ],  # if it's vertex_ai or gemini (google ai studio)
         timeout: Optional[Union[float, httpx.Timeout]],
         encoding,
         logging_obj,
         stream,
         optional_params: dict,
         litellm_params: dict,
         headers: dict,
         logger_fn=None,
+        api_base: Optional[str] = None,
         client: Optional[AsyncHTTPHandler] = None,
+        vertex_project: Optional[str] = None,
+        vertex_location: Optional[str] = None,
+        vertex_credentials: Optional[str] = None,
+        extra_headers: Optional[dict] = None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
+
+        should_use_v1beta1_features = self.is_using_v1beta1_features(
+            optional_params=optional_params
+        )
+
+        _auth_header, vertex_project = await self._ensure_access_token_async(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=None,
+            auth_header=_auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=stream,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=should_use_v1beta1_features,
+        )
+
+        headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
+
+        request_body = await async_transform_request_body(**data)  # type: ignore
         _async_client_params = {}
         if timeout:

@@ -1373,22 +1243,6 @@ class VertexLLM(BaseLLM):
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore

-        should_use_v1beta1_features = self.is_using_v1beta1_features(
-            optional_params=optional_params
-        )
-
-        auth_header, url = self._get_token_and_url(
-            model=model,
-            gemini_api_key=gemini_api_key,
-            vertex_project=vertex_project,
-            vertex_location=vertex_location,
-            vertex_credentials=vertex_credentials,
-            stream=stream,
-            custom_llm_provider=custom_llm_provider,
-            api_base=api_base,
-            should_use_v1beta1_features=should_use_v1beta1_features,
-        )
-
         transform_request_params = {
             "gemini_api_key": gemini_api_key,
             "messages": messages,

@@ -1403,8 +1257,6 @@ class VertexLLM(BaseLLM):
             "litellm_params": litellm_params,
         }

-        headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
-
         ### ROUTING (ASYNC, STREAMING, SYNC)
         if acompletion:
             ### ASYNC STREAMING

@@ -1412,7 +1264,7 @@ class VertexLLM(BaseLLM):
                 return self.async_streaming(
                     model=model,
                     messages=messages,
-                    api_base=url,
+                    api_base=api_base,
                     model_response=model_response,
                     print_verbose=print_verbose,
                     encoding=encoding,

@@ -1424,14 +1276,18 @@ class VertexLLM(BaseLLM):
                     timeout=timeout,
                    client=client,  # type: ignore
                     data=transform_request_params,
                     headers=headers,
+                    vertex_project=vertex_project,
+                    vertex_location=vertex_location,
+                    vertex_credentials=vertex_credentials,
                     custom_llm_provider=custom_llm_provider,
+                    extra_headers=extra_headers,
                 )
             ### ASYNC COMPLETION
             return self.async_completion(
                 model=model,
                 messages=messages,
                 data=transform_request_params,  # type: ignore
-                api_base=url,
+                api_base=api_base,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 encoding=encoding,

@@ -1442,10 +1298,35 @@ class VertexLLM(BaseLLM):
                 logger_fn=logger_fn,
                 timeout=timeout,
                 client=client,  # type: ignore
                 headers=headers,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+                vertex_credentials=vertex_credentials,
                 custom_llm_provider=custom_llm_provider,
+                extra_headers=extra_headers,
             )

         ## SYNC STREAMING CALL ##
+        should_use_v1beta1_features = self.is_using_v1beta1_features(
+            optional_params=optional_params
+        )
+
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            gemini_api_key=gemini_api_key,
+            auth_header=_auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=stream,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=should_use_v1beta1_features,
+        )
+        headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)

         ## TRANSFORMATION ##
         data = sync_transform_request_body(**transform_request_params)

@@ -1460,6 +1341,7 @@ class VertexLLM(BaseLLM):
             },
         )

+        ## SYNC STREAMING CALL ##
         if stream is True:
             request_data_str = json.dumps(data)
             streaming_response = CustomStreamWrapper(
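One piece of routing logic worth calling out from the hunks above: is_using_v1beta1_features (now on VertexBase, reproduced below) decides between the v1 and v1beta1 Vertex endpoints, and context caching is the only trigger. A small illustration, assuming a constructed VertexLLM instance:

llm = VertexLLM()

# Context caching is only supported on v1beta1, so its presence flips the version.
assert llm.is_using_v1beta1_features({"cached_content": "projects/..."}) is True
assert llm.is_using_v1beta1_features({"temperature": 0.2}) is False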
@@ -43,8 +43,14 @@ class GoogleBatchEmbeddings(VertexLLM):
         client=None,
     ) -> EmbeddingResponse:

+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
+
         auth_header, url = self._get_token_and_url(
             model=model,
+            auth_header=_auth_header,
             gemini_api_key=api_key,
             vertex_project=vertex_project,
             vertex_location=vertex_location,
@@ -43,8 +43,15 @@ class VertexMultimodalEmbedding(VertexLLM):
         timeout=300,
         client=None,
     ):
+
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
+
         auth_header, url = self._get_token_and_url(
             model=model,
+            auth_header=_auth_header,
             gemini_api_key=api_key,
             vertex_project=vertex_project,
             vertex_location=vertex_location,
@@ -65,8 +65,15 @@ class VertexTextToSpeechAPI(VertexLLM):
         import base64

         ####### Authenticate with Vertex AI ########
+
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
+
         auth_header, _ = self._get_token_and_url(
             model="",
+            auth_header=_auth_header,
             gemini_api_key=None,
             vertex_credentials=vertex_credentials,
             vertex_project=vertex_project,
litellm/llms/vertex_ai_and_google_ai_studio/vertex_llm_base.py  (new file, 255 lines)

@@ -0,0 +1,255 @@
+import json
+import os
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
+
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.asyncify import asyncify
+from litellm.llms.base import BaseLLM
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
+from .common_utils import (
+    VertexAIError,
+    _get_gemini_url,
+    _get_vertex_url,
+    all_gemini_url_modes,
+    get_supports_system_message,
+)
+
+if TYPE_CHECKING:
+    from google.auth.credentials import Credentials as GoogleCredentialsObject
+else:
+    GoogleCredentialsObject = Any
+
+
+class VertexBase(BaseLLM):
+    def __init__(self) -> None:
+        super().__init__()
+        self.access_token: Optional[str] = None
+        self.refresh_token: Optional[str] = None
+        self._credentials: Optional[GoogleCredentialsObject] = None
+        self.project_id: Optional[str] = None
+        self.async_handler: Optional[AsyncHTTPHandler] = None
+
+    def get_vertex_region(self, vertex_region: Optional[str]) -> str:
+        return vertex_region or "us-central1"
+
+    def load_auth(
+        self, credentials: Optional[str], project_id: Optional[str]
+    ) -> Tuple[Any, str]:
+        import google.auth as google_auth
+        from google.auth import identity_pool
+        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
+        from google.auth.transport.requests import (
+            Request,  # type: ignore[import-untyped]
+        )
+
+        if credentials is not None and isinstance(credentials, str):
+            import google.oauth2.service_account
+
+            verbose_logger.debug(
+                "Vertex: Loading vertex credentials from %s", credentials
+            )
+            verbose_logger.debug(
+                "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
+                credentials,
+                os.path.exists(credentials),
+                os.getcwd(),
+            )
+
+            try:
+                if os.path.exists(credentials):
+                    json_obj = json.load(open(credentials))
+                else:
+                    json_obj = json.loads(credentials)
+            except Exception:
+                raise Exception(
+                    "Unable to load vertex credentials from environment. Got={}".format(
+                        credentials
+                    )
+                )
+
+            # Check if the JSON object contains Workload Identity Federation configuration
+            if "type" in json_obj and json_obj["type"] == "external_account":
+                creds = identity_pool.Credentials.from_info(json_obj)
+            else:
+                creds = (
+                    google.oauth2.service_account.Credentials.from_service_account_info(
+                        json_obj,
+                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
+                    )
+                )
+
+            if project_id is None:
+                project_id = creds.project_id
+        else:
+            creds, creds_project_id = google_auth.default(
+                quota_project_id=project_id,
+                scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            )
+            if project_id is None:
+                project_id = creds_project_id
+
+        creds.refresh(Request())
+
+        if not project_id:
+            raise ValueError("Could not resolve project_id")
+
+        if not isinstance(project_id, str):
+            raise TypeError(
+                f"Expected project_id to be a str but got {type(project_id)}"
+            )
+
+        return creds, project_id
+
+    def refresh_auth(self, credentials: Any) -> None:
+        from google.auth.transport.requests import (
+            Request,  # type: ignore[import-untyped]
+        )
+
+        credentials.refresh(Request())
+
+    def _ensure_access_token(
+        self, credentials: Optional[str], project_id: Optional[str]
+    ) -> Tuple[str, str]:
+        """
+        Returns auth token and project id
+        """
+        if self.access_token is not None:
+            if project_id is not None:
+                return self.access_token, project_id
+            elif self.project_id is not None:
+                return self.access_token, self.project_id
+
+        if not self._credentials:
+            self._credentials, cred_project_id = self.load_auth(
+                credentials=credentials, project_id=project_id
+            )
+            if not self.project_id:
+                self.project_id = project_id or cred_project_id
+        else:
+            if self._credentials.expired or not self._credentials.token:
+                self.refresh_auth(self._credentials)
+
+            if not self.project_id:
+                self.project_id = self._credentials.quota_project_id
+
+        if not self.project_id:
+            raise ValueError("Could not resolve project_id")
+
+        if not self._credentials or not self._credentials.token:
+            raise RuntimeError("Could not resolve API token from the environment")
+
+        return self._credentials.token, project_id or self.project_id
+
+    def is_using_v1beta1_features(self, optional_params: dict) -> bool:
+        """
+        VertexAI only supports ContextCaching on v1beta1
+
+        use this helper to decide if request should be sent to v1 or v1beta1
+
+        Returns v1beta1 if context caching is enabled
+        Returns v1 in all other cases
+        """
+        if "cached_content" in optional_params:
+            return True
+        if "CachedContent" in optional_params:
+            return True
+        return False
+
+    def _get_token_and_url(
+        self,
+        model: str,
+        auth_header: str,
+        gemini_api_key: Optional[str],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        stream: Optional[bool],
+        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+        api_base: Optional[str],
+        should_use_v1beta1_features: Optional[bool] = False,
+        mode: all_gemini_url_modes = "chat",
+    ) -> Tuple[Optional[str], str]:
+        """
+        Internal function. Returns the token and url for the call.
+
+        Handles logic if it's google ai studio vs. vertex ai.
+
+        Returns
+            token, url
+        """
+        if custom_llm_provider == "gemini":
+            url, endpoint = _get_gemini_url(
+                mode=mode,
+                model=model,
+                stream=stream,
+                gemini_api_key=gemini_api_key,
+            )
+        else:
+            vertex_location = self.get_vertex_region(vertex_region=vertex_location)
+
+            ### SET RUNTIME ENDPOINT ###
+            version: Literal["v1beta1", "v1"] = (
+                "v1beta1" if should_use_v1beta1_features is True else "v1"
+            )
+            url, endpoint = _get_vertex_url(
+                mode=mode,
+                model=model,
+                stream=stream,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+                vertex_api_version=version,
+            )
+
+        if (
+            api_base is not None
+        ):  # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
+            if custom_llm_provider == "gemini":
+                url = "{}:{}".format(api_base, endpoint)
+                if gemini_api_key is None:
+                    raise ValueError(
+                        "Missing gemini_api_key, please set `GEMINI_API_KEY`"
+                    )
+                auth_header = (
+                    gemini_api_key  # cloudflare expects api key as bearer token
+                )
+            else:
+                url = "{}:{}".format(api_base, endpoint)
+
+        if stream is True:
+            url = url + "?alt=sse"
+
+        return auth_header, url
+
+    async def _ensure_access_token_async(
+        self, credentials: Optional[str], project_id: Optional[str]
+    ) -> Tuple[str, str]:
+        """
+        Async version of _ensure_access_token
+        """
+        if self.access_token is not None:
+            if project_id is not None:
+                return self.access_token, project_id
+            elif self.project_id is not None:
+                return self.access_token, self.project_id
+
+        if not self._credentials:
+            self._credentials, cred_project_id = await asyncify(self.load_auth)(
+                credentials=credentials, project_id=project_id
+            )
+            if not self.project_id:
+                self.project_id = project_id or cred_project_id
+        else:
+            if self._credentials.expired or not self._credentials.token:
+                await asyncify(self.refresh_auth)(self._credentials)
+
+            if not self.project_id:
+                self.project_id = self._credentials.quota_project_id
+
+        if not self.project_id:
+            raise ValueError("Could not resolve project_id")
+
+        if not self._credentials or not self._credentials.token:
+            raise RuntimeError("Could not resolve API token from the environment")
+
+        return self._credentials.token, project_id or self.project_id
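The new async path never calls google-auth directly on the event loop: load_auth and refresh_auth are wrapped with asyncify from litellm.litellm_core_utils.asyncify. That wrapper's internals are outside this diff; below is a rough stand-in with the same calling shape, assuming it simply offloads the blocking callable to a thread pool:

import asyncio
import functools
from typing import Any, Awaitable, Callable, TypeVar

T = TypeVar("T")

def asyncify(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
    # Rough stand-in for litellm's asyncify: run the blocking callable
    # (e.g. a google-auth credential load/refresh) on the default thread
    # pool so it doesn't stall the event loop.
    async def wrapper(*args: Any, **kwargs: Any) -> T:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, functools.partial(func, *args, **kwargs)
        )

    return wrapper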
@@ -150,8 +150,15 @@ async def vertex_proxy_route(

    base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"

+    _auth_header, vertex_project = (
+        await vertex_fine_tuning_apis_instance._ensure_access_token_async(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+    )
+
    auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
        model="",
+        auth_header=_auth_header,
        gemini_api_key=None,
        vertex_credentials=vertex_credentials,
        vertex_project=vertex_project,
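Because the proxy route reuses the same instance, repeated passthrough calls hit the cached-token fast path in _ensure_access_token / _ensure_access_token_async rather than reloading credentials each time. A hypothetical illustration of that fast path (token and project values invented; attribute names from vertex_llm_base.py above):

from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase

base = VertexBase()
base.access_token = "ya29.cached-token"  # pretend a token was already issued
base.project_id = "my-gcp-project"       # hypothetical project id

# Returns immediately from the cache; load_auth / refresh_auth are never called.
token, project = base._ensure_access_token(credentials=None, project_id=None)
assert (token, project) == ("ya29.cached-token", "my-gcp-project")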