fix vertex use async func to set auth creds

Ishaan Jaff 2024-09-10 16:12:18 -07:00
parent 26ae86e59b
commit 1c6f8b1be2
8 changed files with 420 additions and 230 deletions

@@ -34,10 +34,18 @@ class GCSBucketBase(CustomLogger):
async def construct_request_headers(self) -> Dict[str, str]:
from litellm import vertex_chat_completion
_auth_header, vertex_project = (
await vertex_chat_completion._ensure_access_token_async(
credentials=self.path_service_account_json,
project_id=None,
)
)
auth_header, _ = vertex_chat_completion._get_token_and_url(
model="gcs-bucket",
auth_header=_auth_header,
vertex_credentials=self.path_service_account_json,
-vertex_project=None,
+vertex_project=vertex_project,
vertex_location=None,
gemini_api_key=None,
stream=None,
@@ -55,10 +63,16 @@ class GCSBucketBase(CustomLogger):
def sync_construct_request_headers(self) -> Dict[str, str]:
from litellm import vertex_chat_completion
_auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
credentials=self.path_service_account_json,
project_id=None,
)
auth_header, _ = vertex_chat_completion._get_token_and_url(
model="gcs-bucket",
auth_header=_auth_header,
vertex_credentials=self.path_service_account_json,
-vertex_project=None,
+vertex_project=vertex_project,
vertex_location=None,
gemini_api_key=None,
stream=None,
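
The refactor in this file shows the new two-step pattern used throughout the commit: token resolution is awaited first (so google-auth's blocking refresh no longer runs on the event loop), and _get_token_and_url receives the pre-fetched token through its new auth_header parameter instead of fetching credentials itself. A minimal sketch of the async variant, assuming a service-account JSON path (build_gcs_headers and the returned header shape are illustrative, not the exact method body):

import asyncio
from typing import Dict

from litellm import vertex_chat_completion


async def build_gcs_headers(service_account_json: str) -> Dict[str, str]:
    # Step 1: resolve the bearer token without blocking the event loop.
    token, project_id = await vertex_chat_completion._ensure_access_token_async(
        credentials=service_account_json,
        project_id=None,
    )
    # Step 2: build the URL/header pair; the token is passed in, not re-fetched.
    auth_header, _url = vertex_chat_completion._get_token_and_url(
        model="gcs-bucket",
        auth_header=token,
        gemini_api_key=None,
        vertex_credentials=service_account_json,
        vertex_project=project_id,
        vertex_location=None,
        stream=None,
        custom_llm_provider="vertex_ai",
        api_base=None,
    )
    return {"Authorization": f"Bearer {auth_header}"}  # assumed header shape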

@@ -185,8 +185,14 @@ class VertexFineTuningAPI(VertexLLM):
"creating fine tuning job, args= %s", create_fine_tuning_job_data
)
_auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
)
auth_header, _ = self._get_token_and_url(
model="",
auth_header=_auth_header,
gemini_api_key=None,
vertex_credentials=vertex_credentials,
vertex_project=vertex_project,
@@ -251,8 +257,14 @@ class VertexFineTuningAPI(VertexLLM):
vertex_credentials: str,
request_route: str,
):
_auth_header, vertex_project = await self._ensure_access_token_async(
credentials=vertex_credentials,
project_id=vertex_project,
)
auth_header, _ = self._get_token_and_url(
model="",
auth_header=_auth_header,
gemini_api_key=None,
vertex_credentials=vertex_credentials,
vertex_project=vertex_project,

@@ -72,17 +72,13 @@ from ..common_utils import (
all_gemini_url_modes,
get_supports_system_message,
)
from ..vertex_llm_base import VertexBase
from .transformation import (
async_transform_request_body,
set_headers,
sync_transform_request_body,
)
if TYPE_CHECKING:
from google.auth.credentials import Credentials as GoogleCredentialsObject
else:
GoogleCredentialsObject = Any
class VertexAIConfig:
"""
@@ -821,14 +817,9 @@ def make_sync_call(
return completion_stream
-class VertexLLM(BaseLLM):
+class VertexLLM(VertexBase):
def __init__(self) -> None:
super().__init__()
self.access_token: Optional[str] = None
self.refresh_token: Optional[str] = None
self._credentials: Optional[GoogleCredentialsObject] = None
self.project_id: Optional[str] = None
self.async_handler: Optional[AsyncHTTPHandler] = None
def _process_response(
self,
@@ -1057,201 +1048,13 @@ class VertexLLM(BaseLLM):
return model_response
def get_vertex_region(self, vertex_region: Optional[str]) -> str:
return vertex_region or "us-central1"
def load_auth(
self, credentials: Optional[str], project_id: Optional[str]
) -> Tuple[Any, str]:
import google.auth as google_auth
from google.auth import identity_pool
from google.auth.credentials import Credentials # type: ignore[import-untyped]
from google.auth.transport.requests import (
Request, # type: ignore[import-untyped]
)
if credentials is not None and isinstance(credentials, str):
import google.oauth2.service_account
verbose_logger.debug(
"Vertex: Loading vertex credentials from %s", credentials
)
verbose_logger.debug(
"Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
credentials,
os.path.exists(credentials),
os.getcwd(),
)
try:
if os.path.exists(credentials):
json_obj = json.load(open(credentials))
else:
json_obj = json.loads(credentials)
except Exception:
raise Exception(
"Unable to load vertex credentials from environment. Got={}".format(
credentials
)
)
# Check if the JSON object contains Workload Identity Federation configuration
if "type" in json_obj and json_obj["type"] == "external_account":
creds = identity_pool.Credentials.from_info(json_obj)
else:
creds = (
google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
)
if project_id is None:
project_id = creds.project_id
else:
creds, creds_project_id = google_auth.default(
quota_project_id=project_id,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
if project_id is None:
project_id = creds_project_id
creds.refresh(Request())
if not project_id:
raise ValueError("Could not resolve project_id")
if not isinstance(project_id, str):
raise TypeError(
f"Expected project_id to be a str but got {type(project_id)}"
)
return creds, project_id
def refresh_auth(self, credentials: Any) -> None:
from google.auth.transport.requests import (
Request, # type: ignore[import-untyped]
)
credentials.refresh(Request())
def _ensure_access_token(
self, credentials: Optional[str], project_id: Optional[str]
) -> Tuple[str, str]:
"""
Returns auth token and project id
"""
if self.access_token is not None:
if project_id is not None:
return self.access_token, project_id
elif self.project_id is not None:
return self.access_token, self.project_id
if not self._credentials:
self._credentials, cred_project_id = self.load_auth(
credentials=credentials, project_id=project_id
)
if not self.project_id:
self.project_id = project_id or cred_project_id
else:
if self._credentials.expired or not self._credentials.token:
self.refresh_auth(self._credentials)
if not self.project_id:
self.project_id = self._credentials.quota_project_id
if not self.project_id:
raise ValueError("Could not resolve project_id")
if not self._credentials or not self._credentials.token:
raise RuntimeError("Could not resolve API token from the environment")
return self._credentials.token, project_id or self.project_id
def is_using_v1beta1_features(self, optional_params: dict) -> bool:
"""
VertexAI only supports ContextCaching on v1beta1
use this helper to decide if request should be sent to v1 or v1beta1
Returns v1beta1 if context caching is enabled
Returns v1 in all other cases
"""
if "cached_content" in optional_params:
return True
if "CachedContent" in optional_params:
return True
return False
def _get_token_and_url(
self,
model: str,
gemini_api_key: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
stream: Optional[bool],
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
api_base: Optional[str],
should_use_v1beta1_features: Optional[bool] = False,
mode: all_gemini_url_modes = "chat",
) -> Tuple[Optional[str], str]:
"""
Internal function. Returns the token and url for the call.
Handles logic if it's google ai studio vs. vertex ai.
Returns
token, url
"""
if custom_llm_provider == "gemini":
auth_header = None
url, endpoint = _get_gemini_url(
mode=mode,
model=model,
stream=stream,
gemini_api_key=gemini_api_key,
)
else:
auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
### SET RUNTIME ENDPOINT ###
version: Literal["v1beta1", "v1"] = (
"v1beta1" if should_use_v1beta1_features is True else "v1"
)
url, endpoint = _get_vertex_url(
mode=mode,
model=model,
stream=stream,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_api_version=version,
)
if (
api_base is not None
): # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
if custom_llm_provider == "gemini":
url = "{}:{}".format(api_base, endpoint)
auth_header = (
gemini_api_key # cloudflare expects api key as bearer token
)
else:
url = "{}:{}".format(api_base, endpoint)
if stream is True:
url = url + "?alt=sse"
return auth_header, url
async def async_streaming(
self,
model: str,
custom_llm_provider: Literal[
"vertex_ai", "vertex_ai_beta", "gemini"
], # if it's vertex_ai or gemini (google ai studio)
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: dict,
@@ -1262,11 +1065,49 @@ class VertexLLM(BaseLLM):
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
api_base: Optional[str] = None,
client: Optional[AsyncHTTPHandler] = None,
vertex_project: Optional[str] = None,
vertex_location: Optional[str] = None,
vertex_credentials: Optional[str] = None,
extra_headers: Optional[dict] = None,
) -> CustomStreamWrapper:
request_body = await async_transform_request_body(**data) # type: ignore
should_use_v1beta1_features = self.is_using_v1beta1_features(
optional_params=optional_params
)
_auth_header, vertex_project = await self._ensure_access_token_async(
credentials=vertex_credentials, project_id=vertex_project
)
auth_header, api_base = self._get_token_and_url(
model=model,
gemini_api_key=None,
auth_header=_auth_header,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=stream,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
should_use_v1beta1_features=should_use_v1beta1_features,
)
headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
request_body_str = json.dumps(request_body)
streaming_response = CustomStreamWrapper(
completion_stream=None,
@@ -1290,21 +1131,50 @@ class VertexLLM(BaseLLM):
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: dict,
custom_llm_provider: Literal[
"vertex_ai", "vertex_ai_beta", "gemini"
], # if it's vertex_ai or gemini (google ai studio)
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params: dict,
headers: dict,
logger_fn=None,
api_base: Optional[str] = None,
client: Optional[AsyncHTTPHandler] = None,
vertex_project: Optional[str] = None,
vertex_location: Optional[str] = None,
vertex_credentials: Optional[str] = None,
extra_headers: Optional[dict] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
should_use_v1beta1_features = self.is_using_v1beta1_features(
optional_params=optional_params
)
_auth_header, vertex_project = await self._ensure_access_token_async(
credentials=vertex_credentials, project_id=vertex_project
)
auth_header, api_base = self._get_token_and_url(
model=model,
gemini_api_key=None,
auth_header=_auth_header,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=stream,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
should_use_v1beta1_features=should_use_v1beta1_features,
)
headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
request_body = await async_transform_request_body(**data) # type: ignore
_async_client_params = {}
if timeout:
@@ -1373,22 +1243,6 @@ class VertexLLM(BaseLLM):
) -> Union[ModelResponse, CustomStreamWrapper]:
stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore
should_use_v1beta1_features = self.is_using_v1beta1_features(
optional_params=optional_params
)
auth_header, url = self._get_token_and_url(
model=model,
gemini_api_key=gemini_api_key,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=stream,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
should_use_v1beta1_features=should_use_v1beta1_features,
)
transform_request_params = {
"gemini_api_key": gemini_api_key,
"messages": messages,
@@ -1403,8 +1257,6 @@ class VertexLLM(BaseLLM):
"litellm_params": litellm_params,
}
headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
### ASYNC STREAMING
@@ -1412,7 +1264,7 @@ class VertexLLM(BaseLLM):
return self.async_streaming(
model=model,
messages=messages,
-api_base=url,
+api_base=api_base,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
@@ -1424,14 +1276,18 @@ class VertexLLM(BaseLLM):
timeout=timeout,
client=client, # type: ignore
data=transform_request_params,
headers=headers,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
custom_llm_provider=custom_llm_provider,
extra_headers=extra_headers,
)
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=transform_request_params, # type: ignore
-api_base=url,
+api_base=api_base,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
@@ -1442,10 +1298,35 @@ class VertexLLM(BaseLLM):
logger_fn=logger_fn,
timeout=timeout,
client=client, # type: ignore
headers=headers,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
custom_llm_provider=custom_llm_provider,
extra_headers=extra_headers,
)
## SYNC STREAMING CALL ##
should_use_v1beta1_features = self.is_using_v1beta1_features(
optional_params=optional_params
)
_auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
auth_header, url = self._get_token_and_url(
model=model,
gemini_api_key=gemini_api_key,
auth_header=_auth_header,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=stream,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
should_use_v1beta1_features=should_use_v1beta1_features,
)
headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
## TRANSFORMATION ##
data = sync_transform_request_body(**transform_request_params)
@@ -1460,6 +1341,7 @@ class VertexLLM(BaseLLM):
},
)
## SYNC STREAMING CALL ##
if stream is True:
request_data_str = json.dumps(data)
streaming_response = CustomStreamWrapper(
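
The reason this handler now resolves tokens up front with _ensure_access_token_async: google-auth refreshes credentials with a synchronous HTTP round trip (creds.refresh(Request())), which would stall every other coroutine if called directly from the async completion/streaming paths. litellm wraps the blocking call with its asyncify helper; a rough standard-library equivalent of that pattern (asyncio.to_thread stands in for asyncify) looks like this:

import asyncio
import time


def refresh_credentials() -> str:
    # Stand-in for the blocking google-auth call: credentials.refresh(Request()).
    time.sleep(0.5)  # simulates the token-endpoint round trip
    return "ya29.example-token"


async def ensure_access_token_async() -> str:
    # Run the blocking refresh in a worker thread so the event loop stays free.
    return await asyncio.to_thread(refresh_credentials)


async def main() -> None:
    token = await ensure_access_token_async()
    print(token)


if __name__ == "__main__":
    asyncio.run(main())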

@@ -43,8 +43,14 @@ class GoogleBatchEmbeddings(VertexLLM):
client=None,
) -> EmbeddingResponse:
_auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
)
auth_header, url = self._get_token_and_url(
model=model,
auth_header=_auth_header,
gemini_api_key=api_key,
vertex_project=vertex_project,
vertex_location=vertex_location,

@@ -43,8 +43,15 @@ class VertexMultimodalEmbedding(VertexLLM):
timeout=300,
client=None,
):
_auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
)
auth_header, url = self._get_token_and_url(
model=model,
auth_header=_auth_header,
gemini_api_key=api_key,
vertex_project=vertex_project,
vertex_location=vertex_location,

@@ -65,8 +65,15 @@ class VertexTextToSpeechAPI(VertexLLM):
import base64
####### Authenticate with Vertex AI ########
_auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
)
auth_header, _ = self._get_token_and_url(
model="",
auth_header=_auth_header,
gemini_api_key=None,
vertex_credentials=vertex_credentials,
vertex_project=vertex_project,

@@ -0,0 +1,255 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .common_utils import (
VertexAIError,
_get_gemini_url,
_get_vertex_url,
all_gemini_url_modes,
get_supports_system_message,
)
if TYPE_CHECKING:
from google.auth.credentials import Credentials as GoogleCredentialsObject
else:
GoogleCredentialsObject = Any
class VertexBase(BaseLLM):
def __init__(self) -> None:
super().__init__()
self.access_token: Optional[str] = None
self.refresh_token: Optional[str] = None
self._credentials: Optional[GoogleCredentialsObject] = None
self.project_id: Optional[str] = None
self.async_handler: Optional[AsyncHTTPHandler] = None
def get_vertex_region(self, vertex_region: Optional[str]) -> str:
return vertex_region or "us-central1"
def load_auth(
self, credentials: Optional[str], project_id: Optional[str]
) -> Tuple[Any, str]:
import google.auth as google_auth
from google.auth import identity_pool
from google.auth.credentials import Credentials # type: ignore[import-untyped]
from google.auth.transport.requests import (
Request, # type: ignore[import-untyped]
)
if credentials is not None and isinstance(credentials, str):
import google.oauth2.service_account
verbose_logger.debug(
"Vertex: Loading vertex credentials from %s", credentials
)
verbose_logger.debug(
"Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
credentials,
os.path.exists(credentials),
os.getcwd(),
)
try:
if os.path.exists(credentials):
json_obj = json.load(open(credentials))
else:
json_obj = json.loads(credentials)
except Exception:
raise Exception(
"Unable to load vertex credentials from environment. Got={}".format(
credentials
)
)
# Check if the JSON object contains Workload Identity Federation configuration
if "type" in json_obj and json_obj["type"] == "external_account":
creds = identity_pool.Credentials.from_info(json_obj)
else:
creds = (
google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
)
if project_id is None:
project_id = creds.project_id
else:
creds, creds_project_id = google_auth.default(
quota_project_id=project_id,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
if project_id is None:
project_id = creds_project_id
creds.refresh(Request())
if not project_id:
raise ValueError("Could not resolve project_id")
if not isinstance(project_id, str):
raise TypeError(
f"Expected project_id to be a str but got {type(project_id)}"
)
return creds, project_id
def refresh_auth(self, credentials: Any) -> None:
from google.auth.transport.requests import (
Request, # type: ignore[import-untyped]
)
credentials.refresh(Request())
def _ensure_access_token(
self, credentials: Optional[str], project_id: Optional[str]
) -> Tuple[str, str]:
"""
Returns auth token and project id
"""
if self.access_token is not None:
if project_id is not None:
return self.access_token, project_id
elif self.project_id is not None:
return self.access_token, self.project_id
if not self._credentials:
self._credentials, cred_project_id = self.load_auth(
credentials=credentials, project_id=project_id
)
if not self.project_id:
self.project_id = project_id or cred_project_id
else:
if self._credentials.expired or not self._credentials.token:
self.refresh_auth(self._credentials)
if not self.project_id:
self.project_id = self._credentials.quota_project_id
if not self.project_id:
raise ValueError("Could not resolve project_id")
if not self._credentials or not self._credentials.token:
raise RuntimeError("Could not resolve API token from the environment")
return self._credentials.token, project_id or self.project_id
def is_using_v1beta1_features(self, optional_params: dict) -> bool:
"""
VertexAI only supports ContextCaching on v1beta1
use this helper to decide if request should be sent to v1 or v1beta1
Returns v1beta1 if context caching is enabled
Returns v1 in all other cases
"""
if "cached_content" in optional_params:
return True
if "CachedContent" in optional_params:
return True
return False
def _get_token_and_url(
self,
model: str,
auth_header: str,
gemini_api_key: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
stream: Optional[bool],
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
api_base: Optional[str],
should_use_v1beta1_features: Optional[bool] = False,
mode: all_gemini_url_modes = "chat",
) -> Tuple[Optional[str], str]:
"""
Internal function. Returns the token and url for the call.
Handles logic if it's google ai studio vs. vertex ai.
Returns
token, url
"""
if custom_llm_provider == "gemini":
url, endpoint = _get_gemini_url(
mode=mode,
model=model,
stream=stream,
gemini_api_key=gemini_api_key,
)
else:
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
### SET RUNTIME ENDPOINT ###
version: Literal["v1beta1", "v1"] = (
"v1beta1" if should_use_v1beta1_features is True else "v1"
)
url, endpoint = _get_vertex_url(
mode=mode,
model=model,
stream=stream,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_api_version=version,
)
if (
api_base is not None
): # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
if custom_llm_provider == "gemini":
url = "{}:{}".format(api_base, endpoint)
if gemini_api_key is None:
raise ValueError(
"Missing gemini_api_key, please set `GEMINI_API_KEY`"
)
auth_header = (
gemini_api_key # cloudflare expects api key as bearer token
)
else:
url = "{}:{}".format(api_base, endpoint)
if stream is True:
url = url + "?alt=sse"
return auth_header, url
async def _ensure_access_token_async(
self, credentials: Optional[str], project_id: Optional[str]
) -> Tuple[str, str]:
"""
Async version of _ensure_access_token
"""
if self.access_token is not None:
if project_id is not None:
return self.access_token, project_id
elif self.project_id is not None:
return self.access_token, self.project_id
if not self._credentials:
self._credentials, cred_project_id = await asyncify(self.load_auth)(
credentials=credentials, project_id=project_id
)
if not self.project_id:
self.project_id = project_id or cred_project_id
else:
if self._credentials.expired or not self._credentials.token:
await asyncify(self.refresh_auth)(self._credentials)
if not self.project_id:
self.project_id = self._credentials.quota_project_id
if not self.project_id:
raise ValueError("Could not resolve project_id")
if not self._credentials or not self._credentials.token:
raise RuntimeError("Could not resolve API token from the environment")
return self._credentials.token, project_id or self.project_id
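
Note that both _ensure_access_token variants cache credentials on the instance: the first call pays for load_auth, and later calls reuse self._credentials, refreshing only when the token has expired. A hypothetical subclass exercising the async path (the module path in the import is an assumption based on this commit's file layout):

import asyncio

from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase


class MyVertexClient(VertexBase):
    """Illustrative subclass; VertexLLM and VertexFineTuningAPI inherit the same way."""


async def main() -> None:
    client = MyVertexClient()
    # First call loads and refreshes credentials (path or inline JSON both work).
    token, project = await client._ensure_access_token_async(
        credentials="/path/to/service_account.json",  # placeholder path
        project_id=None,
    )
    # Subsequent calls reuse self._credentials and skip load_auth entirely.
    token_again, _ = await client._ensure_access_token_async(
        credentials=None,
        project_id=project,
    )


asyncio.run(main())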

@@ -150,8 +150,15 @@ async def vertex_proxy_route(
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
_auth_header, vertex_project = (
await vertex_fine_tuning_apis_instance._ensure_access_token_async(
credentials=vertex_credentials, project_id=vertex_project
)
)
auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
model="",
auth_header=_auth_header,
gemini_api_key=None,
vertex_credentials=vertex_credentials,
vertex_project=vertex_project,
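
Every call site above ultimately funnels into load_auth, which accepts either a filesystem path or an inline JSON string and branches on the credential type (Workload Identity Federation configs are marked type=external_account; anything else is treated as a service-account key). A condensed, standalone sketch of just that detection logic (helper names are illustrative):

import json
import os


def parse_vertex_credentials(credentials: str) -> dict:
    # Accept either a path to a JSON file or the JSON document itself.
    if os.path.exists(credentials):
        with open(credentials) as f:
            return json.load(f)
    return json.loads(credentials)


def credential_kind(json_obj: dict) -> str:
    # Mirrors the branch in load_auth: external_account -> identity_pool,
    # otherwise google.oauth2.service_account credentials.
    if json_obj.get("type") == "external_account":
        return "identity_pool"
    return "service_account"


print(credential_kind({"type": "external_account"}))  # -> identity_pool
print(credential_kind({"type": "service_account"}))   # -> service_account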