forked from phoenix/litellm-mirror
fix vertex use async func to set auth creds
This commit is contained in:
parent 26ae86e59b · commit 1c6f8b1be2

8 changed files with 420 additions and 230 deletions
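The same mechanical change repeats across all eight files: token resolution is hoisted out of `_get_token_and_url` into an explicit `_ensure_access_token` (sync) or `_ensure_access_token_async` (async) call, and the resolved token is handed back in through a new `auth_header` parameter. A condensed before/after sketch of the call shape, using only names that appear in the hunks below (the surrounding method is elided):

    # Before: _get_token_and_url resolved credentials itself, via blocking
    # google-auth calls, even when invoked from async code paths.
    auth_header, url = self._get_token_and_url(
        model=model,
        gemini_api_key=None,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_credentials=vertex_credentials,
        stream=stream,
        custom_llm_provider="vertex_ai",
        api_base=None,
    )

    # After: the caller resolves the token first (awaited on async paths),
    # and _get_token_and_url only builds the request URL.
    _auth_header, vertex_project = await self._ensure_access_token_async(
        credentials=vertex_credentials, project_id=vertex_project
    )
    auth_header, url = self._get_token_and_url(
        model=model,
        gemini_api_key=None,
        auth_header=_auth_header,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_credentials=vertex_credentials,
        stream=stream,
        custom_llm_provider="vertex_ai",
        api_base=None,
    )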
@@ -34,10 +34,18 @@ class GCSBucketBase(CustomLogger):
     async def construct_request_headers(self) -> Dict[str, str]:
         from litellm import vertex_chat_completion

+        _auth_header, vertex_project = (
+            await vertex_chat_completion._ensure_access_token_async(
+                credentials=self.path_service_account_json,
+                project_id=None,
+            )
+        )
+
         auth_header, _ = vertex_chat_completion._get_token_and_url(
             model="gcs-bucket",
+            auth_header=_auth_header,
             vertex_credentials=self.path_service_account_json,
-            vertex_project=None,
+            vertex_project=vertex_project,
             vertex_location=None,
             gemini_api_key=None,
             stream=None,
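For context, a rough sketch of how this async header path would be consumed; `construct_request_headers` is the method from the hunk above, while the bucket URL and the httpx upload call are illustrative assumptions, not part of this diff:

    import httpx

    async def upload_log_to_gcs(logger, payload: dict) -> None:
        # Awaits token refresh via _ensure_access_token_async instead of
        # blocking the event loop inside google-auth.
        headers = await logger.construct_request_headers()
        async with httpx.AsyncClient() as client:
            # Illustrative endpoint; the real object path is built elsewhere.
            await client.post(
                "https://storage.googleapis.com/upload/storage/v1/b/my-bucket/o?uploadType=media&name=log.json",
                headers=headers,
                json=payload,
            )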
@@ -55,10 +63,16 @@ class GCSBucketBase(CustomLogger):
     def sync_construct_request_headers(self) -> Dict[str, str]:
         from litellm import vertex_chat_completion

+        _auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
+            credentials=self.path_service_account_json,
+            project_id=None,
+        )
+
         auth_header, _ = vertex_chat_completion._get_token_and_url(
             model="gcs-bucket",
+            auth_header=_auth_header,
             vertex_credentials=self.path_service_account_json,
-            vertex_project=None,
+            vertex_project=vertex_project,
             vertex_location=None,
             gemini_api_key=None,
             stream=None,
@@ -185,8 +185,14 @@ class VertexFineTuningAPI(VertexLLM):
             "creating fine tuning job, args= %s", create_fine_tuning_job_data
         )

+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
+
         auth_header, _ = self._get_token_and_url(
             model="",
+            auth_header=_auth_header,
             gemini_api_key=None,
             vertex_credentials=vertex_credentials,
             vertex_project=vertex_project,
@@ -251,8 +257,14 @@ class VertexFineTuningAPI(VertexLLM):
         vertex_credentials: str,
         request_route: str,
     ):
+
+        _auth_header, vertex_project = await self._ensure_access_token_async(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
         auth_header, _ = self._get_token_and_url(
             model="",
+            auth_header=_auth_header,
             gemini_api_key=None,
             vertex_credentials=vertex_credentials,
             vertex_project=vertex_project,
@@ -72,17 +72,13 @@ from ..common_utils import (
     all_gemini_url_modes,
     get_supports_system_message,
 )
+from ..vertex_llm_base import VertexBase
 from .transformation import (
     async_transform_request_body,
     set_headers,
     sync_transform_request_body,
 )

-if TYPE_CHECKING:
-    from google.auth.credentials import Credentials as GoogleCredentialsObject
-else:
-    GoogleCredentialsObject = Any
-

 class VertexAIConfig:
     """
@@ -821,14 +817,9 @@ def make_sync_call(
     return completion_stream


-class VertexLLM(BaseLLM):
+class VertexLLM(VertexBase):
     def __init__(self) -> None:
         super().__init__()
-        self.access_token: Optional[str] = None
-        self.refresh_token: Optional[str] = None
-        self._credentials: Optional[GoogleCredentialsObject] = None
-        self.project_id: Optional[str] = None
-        self.async_handler: Optional[AsyncHTTPHandler] = None

     def _process_response(
         self,
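`VertexLLM` can drop its credential attributes because the identical `__init__` now lives in the new `VertexBase` (added below in `vertex_llm_base.py`), so every subclass shares one token cache. A condensed sketch of the resulting hierarchy, with all non-auth members elided:

    from typing import Any, Optional

    class VertexBase:  # the real class also derives from litellm's BaseLLM
        def __init__(self) -> None:
            self.access_token: Optional[str] = None
            self._credentials: Optional[Any] = None
            self.project_id: Optional[str] = None

    class VertexLLM(VertexBase):
        def __init__(self) -> None:
            super().__init__()  # auth state is inherited, not redeclared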
@@ -1057,201 +1048,13 @@ class VertexLLM(BaseLLM):

         return model_response

-    def get_vertex_region(self, vertex_region: Optional[str]) -> str:
-        return vertex_region or "us-central1"
-
-    def load_auth(
-        self, credentials: Optional[str], project_id: Optional[str]
-    ) -> Tuple[Any, str]:
-        import google.auth as google_auth
-        from google.auth import identity_pool
-        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
-        from google.auth.transport.requests import (
-            Request,  # type: ignore[import-untyped]
-        )
-
-        if credentials is not None and isinstance(credentials, str):
-            import google.oauth2.service_account
-
-            verbose_logger.debug(
-                "Vertex: Loading vertex credentials from %s", credentials
-            )
-            verbose_logger.debug(
-                "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
-                credentials,
-                os.path.exists(credentials),
-                os.getcwd(),
-            )
-
-            try:
-                if os.path.exists(credentials):
-                    json_obj = json.load(open(credentials))
-                else:
-                    json_obj = json.loads(credentials)
-            except Exception:
-                raise Exception(
-                    "Unable to load vertex credentials from environment. Got={}".format(
-                        credentials
-                    )
-                )
-
-            # Check if the JSON object contains Workload Identity Federation configuration
-            if "type" in json_obj and json_obj["type"] == "external_account":
-                creds = identity_pool.Credentials.from_info(json_obj)
-            else:
-                creds = (
-                    google.oauth2.service_account.Credentials.from_service_account_info(
-                        json_obj,
-                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
-                    )
-                )
-
-            if project_id is None:
-                project_id = creds.project_id
-        else:
-            creds, creds_project_id = google_auth.default(
-                quota_project_id=project_id,
-                scopes=["https://www.googleapis.com/auth/cloud-platform"],
-            )
-            if project_id is None:
-                project_id = creds_project_id
-
-        creds.refresh(Request())
-
-        if not project_id:
-            raise ValueError("Could not resolve project_id")
-
-        if not isinstance(project_id, str):
-            raise TypeError(
-                f"Expected project_id to be a str but got {type(project_id)}"
-            )
-
-        return creds, project_id
-
-    def refresh_auth(self, credentials: Any) -> None:
-        from google.auth.transport.requests import (
-            Request,  # type: ignore[import-untyped]
-        )
-
-        credentials.refresh(Request())
-
-    def _ensure_access_token(
-        self, credentials: Optional[str], project_id: Optional[str]
-    ) -> Tuple[str, str]:
-        """
-        Returns auth token and project id
-        """
-        if self.access_token is not None:
-            if project_id is not None:
-                return self.access_token, project_id
-            elif self.project_id is not None:
-                return self.access_token, self.project_id
-
-        if not self._credentials:
-            self._credentials, cred_project_id = self.load_auth(
-                credentials=credentials, project_id=project_id
-            )
-            if not self.project_id:
-                self.project_id = project_id or cred_project_id
-        else:
-            if self._credentials.expired or not self._credentials.token:
-                self.refresh_auth(self._credentials)
-
-            if not self.project_id:
-                self.project_id = self._credentials.quota_project_id
-
-        if not self.project_id:
-            raise ValueError("Could not resolve project_id")
-
-        if not self._credentials or not self._credentials.token:
-            raise RuntimeError("Could not resolve API token from the environment")
-
-        return self._credentials.token, project_id or self.project_id
-
-    def is_using_v1beta1_features(self, optional_params: dict) -> bool:
-        """
-        VertexAI only supports ContextCaching on v1beta1
-
-        use this helper to decide if request should be sent to v1 or v1beta1
-
-        Returns v1beta1 if context caching is enabled
-        Returns v1 in all other cases
-        """
-        if "cached_content" in optional_params:
-            return True
-        if "CachedContent" in optional_params:
-            return True
-        return False
-
-    def _get_token_and_url(
-        self,
-        model: str,
-        gemini_api_key: Optional[str],
-        vertex_project: Optional[str],
-        vertex_location: Optional[str],
-        vertex_credentials: Optional[str],
-        stream: Optional[bool],
-        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
-        api_base: Optional[str],
-        should_use_v1beta1_features: Optional[bool] = False,
-        mode: all_gemini_url_modes = "chat",
-    ) -> Tuple[Optional[str], str]:
-        """
-        Internal function. Returns the token and url for the call.
-
-        Handles logic if it's google ai studio vs. vertex ai.
-
-        Returns
-            token, url
-        """
-        if custom_llm_provider == "gemini":
-            auth_header = None
-            url, endpoint = _get_gemini_url(
-                mode=mode,
-                model=model,
-                stream=stream,
-                gemini_api_key=gemini_api_key,
-            )
-        else:
-            auth_header, vertex_project = self._ensure_access_token(
-                credentials=vertex_credentials, project_id=vertex_project
-            )
-            vertex_location = self.get_vertex_region(vertex_region=vertex_location)
-
-            ### SET RUNTIME ENDPOINT ###
-            version: Literal["v1beta1", "v1"] = (
-                "v1beta1" if should_use_v1beta1_features is True else "v1"
-            )
-            url, endpoint = _get_vertex_url(
-                mode=mode,
-                model=model,
-                stream=stream,
-                vertex_project=vertex_project,
-                vertex_location=vertex_location,
-                vertex_api_version=version,
-            )
-
-        if (
-            api_base is not None
-        ):  # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
-            if custom_llm_provider == "gemini":
-                url = "{}:{}".format(api_base, endpoint)
-                auth_header = (
-                    gemini_api_key  # cloudflare expects api key as bearer token
-                )
-            else:
-                url = "{}:{}".format(api_base, endpoint)
-
-        if stream is True:
-            url = url + "?alt=sse"
-
-        return auth_header, url
-
     async def async_streaming(
         self,
         model: str,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # if it's vertex_ai or gemini (google ai studio)
         messages: list,
-        api_base: str,
         model_response: ModelResponse,
         print_verbose: Callable,
         data: dict,
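The async variants below depend on litellm's `asyncify` helper to run the blocking google-auth calls on a worker thread. A minimal stand-in with the same shape, assuming `asyncify` is essentially a thread-offloading wrapper (the real implementation lives in `litellm.litellm_core_utils.asyncify`):

    import asyncio
    import functools
    from typing import Any, Awaitable, Callable, TypeVar

    T = TypeVar("T")

    def asyncify(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
        """Wrap a blocking callable so awaiting it never stalls the event loop."""
        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> T:
            return await asyncio.to_thread(func, *args, **kwargs)
        return wrapper

    # Mirrors the call sites in vertex_llm_base.py below:
    #   self._credentials, cred_project_id = await asyncify(self.load_auth)(
    #       credentials=credentials, project_id=project_id
    #   )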
@@ -1262,11 +1065,49 @@
         optional_params: dict,
         litellm_params=None,
         logger_fn=None,
-        headers={},
+        api_base: Optional[str] = None,
         client: Optional[AsyncHTTPHandler] = None,
+        vertex_project: Optional[str] = None,
+        vertex_location: Optional[str] = None,
+        vertex_credentials: Optional[str] = None,
+        extra_headers: Optional[dict] = None,
     ) -> CustomStreamWrapper:
         request_body = await async_transform_request_body(**data)  # type: ignore
+
+        should_use_v1beta1_features = self.is_using_v1beta1_features(
+            optional_params=optional_params
+        )
+
+        _auth_header, vertex_project = await self._ensure_access_token_async(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=None,
+            auth_header=_auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=stream,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=should_use_v1beta1_features,
+        )
+
+        headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key="",
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+
         request_body_str = json.dumps(request_body)
         streaming_response = CustomStreamWrapper(
             completion_stream=None,
@@ -1290,21 +1131,50 @@
         self,
         model: str,
         messages: list,
-        api_base: str,
         model_response: ModelResponse,
         print_verbose: Callable,
         data: dict,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # if it's vertex_ai or gemini (google ai studio)
         timeout: Optional[Union[float, httpx.Timeout]],
         encoding,
         logging_obj,
         stream,
         optional_params: dict,
         litellm_params: dict,
-        headers: dict,
         logger_fn=None,
+        api_base: Optional[str] = None,
         client: Optional[AsyncHTTPHandler] = None,
+        vertex_project: Optional[str] = None,
+        vertex_location: Optional[str] = None,
+        vertex_credentials: Optional[str] = None,
+        extra_headers: Optional[dict] = None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
+
+        should_use_v1beta1_features = self.is_using_v1beta1_features(
+            optional_params=optional_params
+        )
+
+        _auth_header, vertex_project = await self._ensure_access_token_async(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=None,
+            auth_header=_auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=stream,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=should_use_v1beta1_features,
+        )
+
+        headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
+
         request_body = await async_transform_request_body(**data)  # type: ignore
         _async_client_params = {}
         if timeout:
@@ -1373,22 +1243,6 @@
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         stream: Optional[bool] = optional_params.pop("stream", None)  # type: ignore

-        should_use_v1beta1_features = self.is_using_v1beta1_features(
-            optional_params=optional_params
-        )
-
-        auth_header, url = self._get_token_and_url(
-            model=model,
-            gemini_api_key=gemini_api_key,
-            vertex_project=vertex_project,
-            vertex_location=vertex_location,
-            vertex_credentials=vertex_credentials,
-            stream=stream,
-            custom_llm_provider=custom_llm_provider,
-            api_base=api_base,
-            should_use_v1beta1_features=should_use_v1beta1_features,
-        )
-
         transform_request_params = {
             "gemini_api_key": gemini_api_key,
             "messages": messages,
@@ -1403,8 +1257,6 @@
             "litellm_params": litellm_params,
         }

-        headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)
-
         ### ROUTING (ASYNC, STREAMING, SYNC)
         if acompletion:
             ### ASYNC STREAMING
@@ -1412,7 +1264,7 @@
                 return self.async_streaming(
                     model=model,
                     messages=messages,
-                    api_base=url,
+                    api_base=api_base,
                     model_response=model_response,
                     print_verbose=print_verbose,
                     encoding=encoding,
@@ -1424,14 +1276,18 @@
                     timeout=timeout,
                     client=client,  # type: ignore
                     data=transform_request_params,
-                    headers=headers,
+                    vertex_project=vertex_project,
+                    vertex_location=vertex_location,
+                    vertex_credentials=vertex_credentials,
+                    custom_llm_provider=custom_llm_provider,
+                    extra_headers=extra_headers,
                 )
             ### ASYNC COMPLETION
             return self.async_completion(
                 model=model,
                 messages=messages,
                 data=transform_request_params,  # type: ignore
-                api_base=url,
+                api_base=api_base,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 encoding=encoding,
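`completion()` now forwards the raw `vertex_project` / `vertex_location` / `vertex_credentials` instead of a precomputed `url` and `headers`, so the async branches can await token resolution themselves. The motivation, sketched with contrived timings: a blocking refresh inside `async def` freezes every task on the loop, while an awaited thread offload only suspends the current one.

    import asyncio
    import time

    def blocking_refresh() -> str:
        time.sleep(0.5)  # stand-in for google-auth's synchronous token refresh
        return "token"

    async def handle_request(n: int) -> None:
        # Old shape: calling blocking_refresh() directly here would serialize
        # all ten handlers (~5s total). New shape: offload and await (~0.5s).
        token = await asyncio.to_thread(blocking_refresh)
        print(f"request {n} authenticated with {token}")

    async def main() -> None:
        await asyncio.gather(*(handle_request(i) for i in range(10)))

    asyncio.run(main())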
@@ -1442,10 +1298,35 @@
                 logger_fn=logger_fn,
                 timeout=timeout,
                 client=client,  # type: ignore
-                headers=headers,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+                vertex_credentials=vertex_credentials,
+                custom_llm_provider=custom_llm_provider,
+                extra_headers=extra_headers,
             )

-        ## SYNC STREAMING CALL ##
+        should_use_v1beta1_features = self.is_using_v1beta1_features(
+            optional_params=optional_params
+        )
+
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            gemini_api_key=gemini_api_key,
+            auth_header=_auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=stream,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=should_use_v1beta1_features,
+        )
+        headers = set_headers(auth_header=auth_header, extra_headers=extra_headers)

         ## TRANSFORMATION ##
         data = sync_transform_request_body(**transform_request_params)
@@ -1460,6 +1341,7 @@
             },
         )

+        ## SYNC STREAMING CALL ##
         if stream is True:
             request_data_str = json.dumps(data)
             streaming_response = CustomStreamWrapper(
@@ -43,8 +43,14 @@ class GoogleBatchEmbeddings(VertexLLM):
         client=None,
     ) -> EmbeddingResponse:

+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
+
         auth_header, url = self._get_token_and_url(
             model=model,
+            auth_header=_auth_header,
             gemini_api_key=api_key,
             vertex_project=vertex_project,
             vertex_location=vertex_location,
@@ -43,8 +43,15 @@ class VertexMultimodalEmbedding(VertexLLM):
         timeout=300,
         client=None,
     ):
+
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
+
         auth_header, url = self._get_token_and_url(
             model=model,
+            auth_header=_auth_header,
             gemini_api_key=api_key,
             vertex_project=vertex_project,
             vertex_location=vertex_location,
@@ -65,8 +65,15 @@ class VertexTextToSpeechAPI(VertexLLM):
         import base64

         ####### Authenticate with Vertex AI ########
+
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+        )
+
         auth_header, _ = self._get_token_and_url(
             model="",
+            auth_header=_auth_header,
             gemini_api_key=None,
             vertex_credentials=vertex_credentials,
             vertex_project=vertex_project,
litellm/llms/vertex_ai_and_google_ai_studio/vertex_llm_base.py (new file, 255 lines)
@@ -0,0 +1,255 @@
+import json
+import os
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
+
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.asyncify import asyncify
+from litellm.llms.base import BaseLLM
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
+from .common_utils import (
+    VertexAIError,
+    _get_gemini_url,
+    _get_vertex_url,
+    all_gemini_url_modes,
+    get_supports_system_message,
+)
+
+if TYPE_CHECKING:
+    from google.auth.credentials import Credentials as GoogleCredentialsObject
+else:
+    GoogleCredentialsObject = Any
+
+
+class VertexBase(BaseLLM):
+    def __init__(self) -> None:
+        super().__init__()
+        self.access_token: Optional[str] = None
+        self.refresh_token: Optional[str] = None
+        self._credentials: Optional[GoogleCredentialsObject] = None
+        self.project_id: Optional[str] = None
+        self.async_handler: Optional[AsyncHTTPHandler] = None
+
+    def get_vertex_region(self, vertex_region: Optional[str]) -> str:
+        return vertex_region or "us-central1"
+
+    def load_auth(
+        self, credentials: Optional[str], project_id: Optional[str]
+    ) -> Tuple[Any, str]:
+        import google.auth as google_auth
+        from google.auth import identity_pool
+        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
+        from google.auth.transport.requests import (
+            Request,  # type: ignore[import-untyped]
+        )
+
+        if credentials is not None and isinstance(credentials, str):
+            import google.oauth2.service_account
+
+            verbose_logger.debug(
+                "Vertex: Loading vertex credentials from %s", credentials
+            )
+            verbose_logger.debug(
+                "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
+                credentials,
+                os.path.exists(credentials),
+                os.getcwd(),
+            )
+
+            try:
+                if os.path.exists(credentials):
+                    json_obj = json.load(open(credentials))
+                else:
+                    json_obj = json.loads(credentials)
+            except Exception:
+                raise Exception(
+                    "Unable to load vertex credentials from environment. Got={}".format(
+                        credentials
+                    )
+                )
+
+            # Check if the JSON object contains Workload Identity Federation configuration
+            if "type" in json_obj and json_obj["type"] == "external_account":
+                creds = identity_pool.Credentials.from_info(json_obj)
+            else:
+                creds = (
+                    google.oauth2.service_account.Credentials.from_service_account_info(
+                        json_obj,
+                        scopes=["https://www.googleapis.com/auth/cloud-platform"],
+                    )
+                )
+
+            if project_id is None:
+                project_id = creds.project_id
+        else:
+            creds, creds_project_id = google_auth.default(
+                quota_project_id=project_id,
+                scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            )
+            if project_id is None:
+                project_id = creds_project_id
+
+        creds.refresh(Request())
+
+        if not project_id:
+            raise ValueError("Could not resolve project_id")
+
+        if not isinstance(project_id, str):
+            raise TypeError(
+                f"Expected project_id to be a str but got {type(project_id)}"
+            )
+
+        return creds, project_id
+
+    def refresh_auth(self, credentials: Any) -> None:
+        from google.auth.transport.requests import (
+            Request,  # type: ignore[import-untyped]
+        )
+
+        credentials.refresh(Request())
+
+    def _ensure_access_token(
+        self, credentials: Optional[str], project_id: Optional[str]
+    ) -> Tuple[str, str]:
+        """
+        Returns auth token and project id
+        """
+        if self.access_token is not None:
+            if project_id is not None:
+                return self.access_token, project_id
+            elif self.project_id is not None:
+                return self.access_token, self.project_id
+
+        if not self._credentials:
+            self._credentials, cred_project_id = self.load_auth(
+                credentials=credentials, project_id=project_id
+            )
+            if not self.project_id:
+                self.project_id = project_id or cred_project_id
+        else:
+            if self._credentials.expired or not self._credentials.token:
+                self.refresh_auth(self._credentials)
+
+            if not self.project_id:
+                self.project_id = self._credentials.quota_project_id
+
+        if not self.project_id:
+            raise ValueError("Could not resolve project_id")
+
+        if not self._credentials or not self._credentials.token:
+            raise RuntimeError("Could not resolve API token from the environment")
+
+        return self._credentials.token, project_id or self.project_id
+
+    def is_using_v1beta1_features(self, optional_params: dict) -> bool:
+        """
+        VertexAI only supports ContextCaching on v1beta1
+
+        use this helper to decide if request should be sent to v1 or v1beta1
+
+        Returns v1beta1 if context caching is enabled
+        Returns v1 in all other cases
+        """
+        if "cached_content" in optional_params:
+            return True
+        if "CachedContent" in optional_params:
+            return True
+        return False
+
+    def _get_token_and_url(
+        self,
+        model: str,
+        auth_header: str,
+        gemini_api_key: Optional[str],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[str],
+        stream: Optional[bool],
+        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+        api_base: Optional[str],
+        should_use_v1beta1_features: Optional[bool] = False,
+        mode: all_gemini_url_modes = "chat",
+    ) -> Tuple[Optional[str], str]:
+        """
+        Internal function. Returns the token and url for the call.
+
+        Handles logic if it's google ai studio vs. vertex ai.
+
+        Returns
+            token, url
+        """
+        if custom_llm_provider == "gemini":
+            url, endpoint = _get_gemini_url(
+                mode=mode,
+                model=model,
+                stream=stream,
+                gemini_api_key=gemini_api_key,
+            )
+        else:
+            vertex_location = self.get_vertex_region(vertex_region=vertex_location)
+
+            ### SET RUNTIME ENDPOINT ###
+            version: Literal["v1beta1", "v1"] = (
+                "v1beta1" if should_use_v1beta1_features is True else "v1"
+            )
+            url, endpoint = _get_vertex_url(
+                mode=mode,
+                model=model,
+                stream=stream,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+                vertex_api_version=version,
+            )
+
+        if (
+            api_base is not None
+        ):  # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
+            if custom_llm_provider == "gemini":
+                url = "{}:{}".format(api_base, endpoint)
+                if gemini_api_key is None:
+                    raise ValueError(
+                        "Missing gemini_api_key, please set `GEMINI_API_KEY`"
+                    )
+                auth_header = (
+                    gemini_api_key  # cloudflare expects api key as bearer token
+                )
+            else:
+                url = "{}:{}".format(api_base, endpoint)
+
+        if stream is True:
+            url = url + "?alt=sse"
+
+        return auth_header, url
+
+    async def _ensure_access_token_async(
+        self, credentials: Optional[str], project_id: Optional[str]
+    ) -> Tuple[str, str]:
+        """
+        Async version of _ensure_access_token
+        """
+        if self.access_token is not None:
+            if project_id is not None:
+                return self.access_token, project_id
+            elif self.project_id is not None:
+                return self.access_token, self.project_id
+
+        if not self._credentials:
+            self._credentials, cred_project_id = await asyncify(self.load_auth)(
+                credentials=credentials, project_id=project_id
+            )
+            if not self.project_id:
+                self.project_id = project_id or cred_project_id
+        else:
+            if self._credentials.expired or not self._credentials.token:
+                await asyncify(self.refresh_auth)(self._credentials)
+
+            if not self.project_id:
+                self.project_id = self._credentials.quota_project_id
+
+        if not self.project_id:
+            raise ValueError("Could not resolve project_id")
+
+        if not self._credentials or not self._credentials.token:
+            raise RuntimeError("Could not resolve API token from the environment")
+
+        return self._credentials.token, project_id or self.project_id
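Two caching layers in `VertexBase` are worth noting: `_ensure_access_token` returns immediately when `access_token` is already populated, and otherwise reuses the cached `_credentials`, calling `load_auth` only once and `refresh_auth` only on expiry. A contrived illustration of the first layer (the instance values are stubbed by hand for the example):

    base = VertexBase()
    base.access_token = "cached-token"  # stubbed; set by a caller in practice
    base.project_id = "my-project"      # stubbed for the example

    # Returns the cached pair without calling load_auth() or refresh_auth().
    token, project = base._ensure_access_token(credentials=None, project_id=None)
    assert (token, project) == ("cached-token", "my-project")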
@@ -150,8 +150,15 @@ async def vertex_proxy_route(

    base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"

+    _auth_header, vertex_project = (
+        await vertex_fine_tuning_apis_instance._ensure_access_token_async(
+            credentials=vertex_credentials, project_id=vertex_project
+        )
+    )
+
    auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
        model="",
+        auth_header=_auth_header,
        gemini_api_key=None,
        vertex_credentials=vertex_credentials,
        vertex_project=vertex_project,