diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py index f7149c349a..4a4c428941 100644 --- a/litellm/llms/vertex_ai/common_utils.py +++ b/litellm/llms/vertex_ai/common_utils.py @@ -1,3 +1,4 @@ +import re from typing import Dict, List, Literal, Optional, Tuple, Union import httpx @@ -280,3 +281,62 @@ def _convert_vertex_datetime_to_openai_datetime(vertex_datetime: str) -> int: dt = datetime.strptime(vertex_datetime, "%Y-%m-%dT%H:%M:%S.%fZ") # Convert to Unix timestamp (seconds since epoch) return int(dt.timestamp()) + + +def get_vertex_project_id_from_url(url: str) -> Optional[str]: + """ + Get the vertex project id from the url + + `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent` + """ + match = re.search(r"/projects/([^/]+)", url) + return match.group(1) if match else None + + +def get_vertex_location_from_url(url: str) -> Optional[str]: + """ + Get the vertex location from the url + + `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent` + """ + match = re.search(r"/locations/([^/]+)", url) + return match.group(1) if match else None + + +def construct_target_url( + base_url: str, + requested_route: str, + default_vertex_location: Optional[str], + default_vertex_project: Optional[str], +) -> httpx.URL: + """ + Allow user to specify their own project id / location. 
+ + If missing, use defaults + + Handle cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460 + + Constructed Url: + POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents + """ + new_base_url = httpx.URL(base_url) + if "locations" in requested_route: # contains the target project id + location + updated_url = new_base_url.copy_with(path=requested_route) + return updated_url + """ + - Add endpoint version (e.g. v1beta for cachedContent, v1 for rest) + - Add default project id + - Add default location + """ + vertex_version: Literal["v1", "v1beta1"] = "v1" + if "cachedContent" in requested_route: + vertex_version = "v1beta1" + + base_requested_route = "{}/projects/{}/locations/{}".format( + vertex_version, default_vertex_project, default_vertex_location + ) + + updated_requested_route = "/" + base_requested_route + requested_route + + updated_url = new_base_url.copy_with(path=updated_requested_route) + return updated_url diff --git a/litellm/proxy/pass_through_endpoints/common_utils.py b/litellm/proxy/pass_through_endpoints/common_utils.py new file mode 100644 index 0000000000..3a3783dd57 --- /dev/null +++ b/litellm/proxy/pass_through_endpoints/common_utils.py @@ -0,0 +1,16 @@ +from fastapi import Request + + +def get_litellm_virtual_key(request: Request) -> str: + """ + Extract and format API key from request headers. + Prioritizes x-litellm-api-key over Authorization header. 
+ + + Vertex JS SDK uses `Authorization` header, we use `x-litellm-api-key` to pass litellm virtual key + + """ + litellm_api_key = request.headers.get("x-litellm-api-key") + if litellm_api_key: + return f"Bearer {litellm_api_key}" + return request.headers.get("Authorization", "") diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index 4724c7f9d1..be3a903dcc 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -12,10 +12,13 @@ import httpx from fastapi import APIRouter, Depends, HTTPException, Request, Response import litellm +from litellm._logging import verbose_proxy_logger from litellm.constants import BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES +from litellm.llms.vertex_ai.vertex_llm_base import VertexBase from litellm.proxy._types import * from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.pass_through_endpoints.common_utils import get_litellm_virtual_key from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( create_pass_through_route, ) @@ -23,6 +26,7 @@ from litellm.secret_managers.main import get_secret_str from .passthrough_endpoint_router import PassthroughEndpointRouter +vertex_llm_base = VertexBase() router = APIRouter() default_vertex_config = None @@ -417,6 +421,135 @@ async def azure_proxy_route( ) +@router.api_route( + "/vertex-ai/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["Vertex AI Pass-through", "pass-through"], + include_in_schema=False, +) +@router.api_route( + "/vertex_ai/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["Vertex AI Pass-through", "pass-through"], +) +async def vertex_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, +): + """ + Call 
LiteLLM proxy via Vertex AI SDK. + + [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai) + """ + from litellm.llms.vertex_ai.common_utils import ( + construct_target_url, + get_vertex_location_from_url, + get_vertex_project_id_from_url, + ) + + encoded_endpoint = httpx.URL(endpoint).path + verbose_proxy_logger.debug("requested endpoint %s", endpoint) + headers: dict = {} + api_key_to_use = get_litellm_virtual_key(request=request) + user_api_key_dict = await user_api_key_auth( + request=request, + api_key=api_key_to_use, + ) + + vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint) + vertex_location: Optional[str] = get_vertex_location_from_url(endpoint) + vertex_credentials = passthrough_endpoint_router.get_vertex_credentials( + project_id=vertex_project, + location=vertex_location, + ) + + if vertex_credentials is None: + raise Exception( + f"No matching vertex credentials found, for project_id: {vertex_project}, location: {vertex_location}. No default_vertex_config set either." 
+ ) + + # Use headers from the incoming request if no vertex credentials are found + if vertex_credentials.vertex_project is None: + headers = dict(request.headers) or {} + verbose_proxy_logger.debug( + "default_vertex_config not set, incoming request headers %s", headers + ) + base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" + headers.pop("content-length", None) + headers.pop("host", None) + else: + vertex_project = vertex_credentials.vertex_project + vertex_location = vertex_credentials.vertex_location + vertex_credentials_str = vertex_credentials.vertex_credentials + + # Construct base URL for the target endpoint + base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" + + _auth_header, vertex_project = await vertex_llm_base._ensure_access_token_async( + credentials=vertex_credentials_str, + project_id=vertex_project, + custom_llm_provider="vertex_ai_beta", + ) + + auth_header, _ = vertex_llm_base._get_token_and_url( + model="", + auth_header=_auth_header, + gemini_api_key=None, + vertex_credentials=vertex_credentials_str, + vertex_project=vertex_project, + vertex_location=vertex_location, + stream=False, + custom_llm_provider="vertex_ai_beta", + api_base="", + ) + + headers = { + "Authorization": f"Bearer {auth_header}", + } + + request_route = encoded_endpoint + verbose_proxy_logger.debug("request_route %s", request_route) + + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL using httpx + updated_url = construct_target_url( + base_url=base_target_url, + requested_route=encoded_endpoint, + default_vertex_location=vertex_location, + default_vertex_project=vertex_project, + ) + # base_url = httpx.URL(base_target_url) + # updated_url = base_url.copy_with(path=encoded_endpoint) + + verbose_proxy_logger.debug("updated url %s", updated_url) + + ## check for streaming + target = 
str(updated_url) + is_streaming_request = False + if "stream" in str(updated_url): + is_streaming_request = True + target += "?alt=sse" + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=target, + custom_headers=headers, + ) # dynamically construct pass-through endpoint based on incoming path + received_value = await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + stream=is_streaming_request, # type: ignore + ) + + return received_value + + @router.api_route( "/openai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"], diff --git a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py index adf7d0f30c..5267c3b26c 100644 --- a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py +++ b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py @@ -2,6 +2,7 @@ from typing import Dict, Optional from litellm._logging import verbose_logger from litellm.secret_managers.main import get_secret_str +from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials class PassthroughEndpointRouter: @@ -11,6 +12,9 @@ class PassthroughEndpointRouter: def __init__(self): self.credentials: Dict[str, str] = {} + self.deployment_key_to_vertex_credentials: Dict[ + str, VertexPassThroughCredentials + ] = {} def set_pass_through_credentials( self, @@ -62,6 +66,38 @@ class PassthroughEndpointRouter: ) return get_secret_str(_env_variable_name) + def _get_deployment_key( + self, project_id: Optional[str], location: Optional[str] + ) -> Optional[str]: + """ + Get the deployment key for the given project-id, location + """ + if project_id is None or location is None: + return None + return f"{project_id}-{location}" + + def get_vertex_credentials( + self, project_id: Optional[str], location: Optional[str] + ) -> Optional[VertexPassThroughCredentials]: + """ + Get the vertex 
credentials for the given project-id, location + """ + # from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + # default_vertex_config, + # ) + default_vertex_config: Optional[VertexPassThroughCredentials] = None + + deployment_key = self._get_deployment_key( + project_id=project_id, + location=location, + ) + if deployment_key is None: + return default_vertex_config + if deployment_key in self.deployment_key_to_vertex_credentials: + return self.deployment_key_to_vertex_credentials[deployment_key] + else: + return default_vertex_config + def _get_credential_name_for_provider( self, custom_llm_provider: str, diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py index 7444e3d1c1..4b706ed33a 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -1,274 +1,259 @@ -import traceback -from typing import Optional +# import traceback +# from typing import Optional -import httpx -from fastapi import APIRouter, HTTPException, Request, Response, status +# import httpx +# from fastapi import APIRouter, HTTPException, Request, Response, status -import litellm -from litellm._logging import verbose_proxy_logger -from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance -from litellm.proxy._types import * -from litellm.proxy.auth.user_api_key_auth import user_api_key_auth -from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( - create_pass_through_route, -) -from litellm.secret_managers.main import get_secret_str -from litellm.types.passthrough_endpoints.vertex_ai import * +# import litellm +# from litellm._logging import verbose_proxy_logger +# from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance +# from litellm.proxy._types import * +# from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +# from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( 
+# create_pass_through_route, +# ) +# from litellm.secret_managers.main import get_secret_str +# from litellm.types.passthrough_endpoints.vertex_ai import * -from .vertex_passthrough_router import VertexPassThroughRouter +# from .vertex_passthrough_router import VertexPassThroughRouter -router = APIRouter() -vertex_pass_through_router = VertexPassThroughRouter() +# router = APIRouter() +# vertex_pass_through_router = VertexPassThroughRouter() -default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials() +# default_vertex_config: Optional[VertexPassThroughCredentials] = None -def _get_vertex_env_vars() -> VertexPassThroughCredentials: - """ - Helper to get vertex pass through config from environment variables +# def _get_vertex_env_vars() -> VertexPassThroughCredentials: +# """ +# Helper to get vertex pass through config from environment variables - The following environment variables are used: - - DEFAULT_VERTEXAI_PROJECT (project id) - - DEFAULT_VERTEXAI_LOCATION (location) - - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file) - """ - return VertexPassThroughCredentials( - vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"), - vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"), - vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"), - ) +# The following environment variables are used: +# - DEFAULT_VERTEXAI_PROJECT (project id) +# - DEFAULT_VERTEXAI_LOCATION (location) +# - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file) +# """ +# return VertexPassThroughCredentials( +# vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"), +# vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"), +# vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"), +# ) -def set_default_vertex_config(config: Optional[dict] = None): - """Sets vertex configuration from provided config and/or environment variables +# def set_default_vertex_config(config: Optional[dict] 
= None): +# """Sets vertex configuration from provided config and/or environment variables - Args: - config (Optional[dict]): Configuration dictionary - Example: { - "vertex_project": "my-project-123", - "vertex_location": "us-central1", - "vertex_credentials": "os.environ/GOOGLE_CREDS" - } - """ - global default_vertex_config +# Args: +# config (Optional[dict]): Configuration dictionary +# Example: { +# "vertex_project": "my-project-123", +# "vertex_location": "us-central1", +# "vertex_credentials": "os.environ/GOOGLE_CREDS" +# } +# """ +# global default_vertex_config - # Initialize config dictionary if None - if config is None: - default_vertex_config = _get_vertex_env_vars() - return +# # Initialize config dictionary if None +# if config is None: +# default_vertex_config = _get_vertex_env_vars() +# return - if isinstance(config, dict): - for key, value in config.items(): - if isinstance(value, str) and value.startswith("os.environ/"): - config[key] = litellm.get_secret(value) +# if isinstance(config, dict): +# for key, value in config.items(): +# if isinstance(value, str) and value.startswith("os.environ/"): +# config[key] = litellm.get_secret(value) - _set_default_vertex_config(VertexPassThroughCredentials(**config)) +# _set_default_vertex_config(VertexPassThroughCredentials(**config)) -def _set_default_vertex_config( - vertex_pass_through_credentials: VertexPassThroughCredentials, -): - global default_vertex_config - default_vertex_config = vertex_pass_through_credentials +# def _set_default_vertex_config( +# vertex_pass_through_credentials: VertexPassThroughCredentials, +# ): +# global default_vertex_config +# default_vertex_config = vertex_pass_through_credentials -def exception_handler(e: Exception): - verbose_proxy_logger.error( - "litellm.proxy.proxy_server.v1/projects/tuningJobs(): Exception occurred - {}".format( - str(e) - ) - ) - verbose_proxy_logger.debug(traceback.format_exc()) - if isinstance(e, HTTPException): - return ProxyException( - 
message=getattr(e, "message", str(e.detail)), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - else: - error_msg = f"{str(e)}" - return ProxyException( - message=getattr(e, "message", error_msg), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", 500), - ) +# def exception_handler(e: Exception): +# verbose_proxy_logger.error( +# "litellm.proxy.proxy_server.v1/projects/tuningJobs(): Exception occurred - {}".format( +# str(e) +# ) +# ) +# verbose_proxy_logger.debug(traceback.format_exc()) +# if isinstance(e, HTTPException): +# return ProxyException( +# message=getattr(e, "message", str(e.detail)), +# type=getattr(e, "type", "None"), +# param=getattr(e, "param", "None"), +# code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), +# ) +# else: +# error_msg = f"{str(e)}" +# return ProxyException( +# message=getattr(e, "message", error_msg), +# type=getattr(e, "type", "None"), +# param=getattr(e, "param", "None"), +# code=getattr(e, "status_code", 500), +# ) -def construct_target_url( - base_url: str, - requested_route: str, - default_vertex_location: Optional[str], - default_vertex_project: Optional[str], -) -> httpx.URL: - """ - Allow user to specify their own project id / location. +# def construct_target_url( +# base_url: str, +# requested_route: str, +# default_vertex_location: Optional[str], +# default_vertex_project: Optional[str], +# ) -> httpx.URL: +# """ +# Allow user to specify their own project id / location. 
- If missing, use defaults +# If missing, use defaults - Handle cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460 +# Handle cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460 - Constructed Url: - POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents - """ - new_base_url = httpx.URL(base_url) - if "locations" in requested_route: # contains the target project id + location - updated_url = new_base_url.copy_with(path=requested_route) - return updated_url - """ - - Add endpoint version (e.g. v1beta for cachedContent, v1 for rest) - - Add default project id - - Add default location - """ - vertex_version: Literal["v1", "v1beta1"] = "v1" - if "cachedContent" in requested_route: - vertex_version = "v1beta1" +# Constructed Url: +# POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents +# """ +# new_base_url = httpx.URL(base_url) +# if "locations" in requested_route: # contains the target project id + location +# updated_url = new_base_url.copy_with(path=requested_route) +# return updated_url +# """ +# - Add endpoint version (e.g. 
v1beta for cachedContent, v1 for rest) +# - Add default project id +# - Add default location +# """ +# vertex_version: Literal["v1", "v1beta1"] = "v1" +# if "cachedContent" in requested_route: +# vertex_version = "v1beta1" - base_requested_route = "{}/projects/{}/locations/{}".format( - vertex_version, default_vertex_project, default_vertex_location - ) +# base_requested_route = "{}/projects/{}/locations/{}".format( +# vertex_version, default_vertex_project, default_vertex_location +# ) - updated_requested_route = "/" + base_requested_route + requested_route +# updated_requested_route = "/" + base_requested_route + requested_route - updated_url = new_base_url.copy_with(path=updated_requested_route) - return updated_url +# updated_url = new_base_url.copy_with(path=updated_requested_route) +# return updated_url -@router.api_route( - "/vertex-ai/{endpoint:path}", - methods=["GET", "POST", "PUT", "DELETE", "PATCH"], - tags=["Vertex AI Pass-through", "pass-through"], - include_in_schema=False, -) -@router.api_route( - "/vertex_ai/{endpoint:path}", - methods=["GET", "POST", "PUT", "DELETE", "PATCH"], - tags=["Vertex AI Pass-through", "pass-through"], -) -async def vertex_proxy_route( - endpoint: str, - request: Request, - fastapi_response: Response, -): - """ - Call LiteLLM proxy via Vertex AI SDK. +# @router.api_route( +# "/vertex-ai/{endpoint:path}", +# methods=["GET", "POST", "PUT", "DELETE", "PATCH"], +# tags=["Vertex AI Pass-through", "pass-through"], +# include_in_schema=False, +# ) +# @router.api_route( +# "/vertex_ai/{endpoint:path}", +# methods=["GET", "POST", "PUT", "DELETE", "PATCH"], +# tags=["Vertex AI Pass-through", "pass-through"], +# ) +# async def vertex_proxy_route( +# endpoint: str, +# request: Request, +# fastapi_response: Response, +# ): +# """ +# Call LiteLLM proxy via Vertex AI SDK. 
- [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai) - """ - encoded_endpoint = httpx.URL(endpoint).path - verbose_proxy_logger.debug("requested endpoint %s", endpoint) - headers: dict = {} - api_key_to_use = get_litellm_virtual_key(request=request) - user_api_key_dict = await user_api_key_auth( - request=request, - api_key=api_key_to_use, - ) +# [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai) +# """ +# encoded_endpoint = httpx.URL(endpoint).path +# verbose_proxy_logger.debug("requested endpoint %s", endpoint) +# headers: dict = {} +# api_key_to_use = get_litellm_virtual_key(request=request) +# user_api_key_dict = await user_api_key_auth( +# request=request, +# api_key=api_key_to_use, +# ) - vertex_project: Optional[str] = ( - VertexPassThroughRouter._get_vertex_project_id_from_url(endpoint) - ) - vertex_location: Optional[str] = ( - VertexPassThroughRouter._get_vertex_location_from_url(endpoint) - ) - vertex_credentials = vertex_pass_through_router.get_vertex_credentials( - project_id=vertex_project, - location=vertex_location, - ) +# vertex_project: Optional[str] = ( +# VertexPassThroughRouter._get_vertex_project_id_from_url(endpoint) +# ) +# vertex_location: Optional[str] = ( +# VertexPassThroughRouter._get_vertex_location_from_url(endpoint) +# ) +# vertex_credentials = vertex_pass_through_router.get_vertex_credentials( +# project_id=vertex_project, +# location=vertex_location, +# ) - # Use headers from the incoming request if no vertex credentials are found - if vertex_credentials.vertex_project is None: - headers = dict(request.headers) or {} - verbose_proxy_logger.debug( - "default_vertex_config not set, incoming request headers %s", headers - ) - base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" - headers.pop("content-length", None) - headers.pop("host", None) - else: - vertex_project = vertex_credentials.vertex_project - vertex_location = vertex_credentials.vertex_location - vertex_credentials_str = 
vertex_credentials.vertex_credentials +# # Use headers from the incoming request if no vertex credentials are found +# if vertex_credentials.vertex_project is None: +# headers = dict(request.headers) or {} +# verbose_proxy_logger.debug( +# "default_vertex_config not set, incoming request headers %s", headers +# ) +# base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" +# headers.pop("content-length", None) +# headers.pop("host", None) +# else: +# vertex_project = vertex_credentials.vertex_project +# vertex_location = vertex_credentials.vertex_location +# vertex_credentials_str = vertex_credentials.vertex_credentials - # Construct base URL for the target endpoint - base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" +# # Construct base URL for the target endpoint +# base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" - _auth_header, vertex_project = ( - await vertex_fine_tuning_apis_instance._ensure_access_token_async( - credentials=vertex_credentials_str, - project_id=vertex_project, - custom_llm_provider="vertex_ai_beta", - ) - ) +# _auth_header, vertex_project = ( +# await vertex_fine_tuning_apis_instance._ensure_access_token_async( +# credentials=vertex_credentials_str, +# project_id=vertex_project, +# custom_llm_provider="vertex_ai_beta", +# ) +# ) - auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url( - model="", - auth_header=_auth_header, - gemini_api_key=None, - vertex_credentials=vertex_credentials_str, - vertex_project=vertex_project, - vertex_location=vertex_location, - stream=False, - custom_llm_provider="vertex_ai_beta", - api_base="", - ) +# auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url( +# model="", +# auth_header=_auth_header, +# gemini_api_key=None, +# vertex_credentials=vertex_credentials_str, +# vertex_project=vertex_project, +# vertex_location=vertex_location, +# stream=False, +# custom_llm_provider="vertex_ai_beta", +# api_base="", +# ) - 
headers = { - "Authorization": f"Bearer {auth_header}", - } +# headers = { +# "Authorization": f"Bearer {auth_header}", +# } - request_route = encoded_endpoint - verbose_proxy_logger.debug("request_route %s", request_route) +# request_route = encoded_endpoint +# verbose_proxy_logger.debug("request_route %s", request_route) - # Ensure endpoint starts with '/' for proper URL construction - if not encoded_endpoint.startswith("/"): - encoded_endpoint = "/" + encoded_endpoint +# # Ensure endpoint starts with '/' for proper URL construction +# if not encoded_endpoint.startswith("/"): +# encoded_endpoint = "/" + encoded_endpoint - # Construct the full target URL using httpx - updated_url = construct_target_url( - base_url=base_target_url, - requested_route=encoded_endpoint, - default_vertex_location=vertex_location, - default_vertex_project=vertex_project, - ) - # base_url = httpx.URL(base_target_url) - # updated_url = base_url.copy_with(path=encoded_endpoint) +# # Construct the full target URL using httpx +# updated_url = construct_target_url( +# base_url=base_target_url, +# requested_route=encoded_endpoint, +# default_vertex_location=vertex_location, +# default_vertex_project=vertex_project, +# ) +# # base_url = httpx.URL(base_target_url) +# # updated_url = base_url.copy_with(path=encoded_endpoint) - verbose_proxy_logger.debug("updated url %s", updated_url) +# verbose_proxy_logger.debug("updated url %s", updated_url) - ## check for streaming - target = str(updated_url) - is_streaming_request = False - if "stream" in str(updated_url): - is_streaming_request = True - target += "?alt=sse" +# ## check for streaming +# target = str(updated_url) +# is_streaming_request = False +# if "stream" in str(updated_url): +# is_streaming_request = True +# target += "?alt=sse" - ## CREATE PASS-THROUGH - endpoint_func = create_pass_through_route( - endpoint=endpoint, - target=target, - custom_headers=headers, - ) # dynamically construct pass-through endpoint based on incoming path - 
received_value = await endpoint_func( - request, - fastapi_response, - user_api_key_dict, - stream=is_streaming_request, # type: ignore - ) +# ## CREATE PASS-THROUGH +# endpoint_func = create_pass_through_route( +# endpoint=endpoint, +# target=target, +# custom_headers=headers, +# ) # dynamically construct pass-through endpoint based on incoming path +# received_value = await endpoint_func( +# request, +# fastapi_response, +# user_api_key_dict, +# stream=is_streaming_request, # type: ignore +# ) - return received_value - - -def get_litellm_virtual_key(request: Request) -> str: - """ - Extract and format API key from request headers. - Prioritizes x-litellm-api-key over Authorization header. - - - Vertex JS SDK uses `Authorization` header, we use `x-litellm-api-key` to pass litellm virtual key - - """ - litellm_api_key = request.headers.get("x-litellm-api-key") - if litellm_api_key: - return f"Bearer {litellm_api_key}" - return request.headers.get("Authorization", "") +# return received_value diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py b/litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py index 0273a62047..5017a8f661 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py @@ -1,121 +1,108 @@ -import json -import re -from typing import Dict, Optional +# import json +# import re +# from typing import Dict, Optional -from litellm._logging import verbose_proxy_logger -from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( - VertexPassThroughCredentials, -) -from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES +# from litellm._logging import verbose_proxy_logger +# from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( +# VertexPassThroughCredentials, +# ) +# from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES -class VertexPassThroughRouter: - """ - Vertex Pass Through Router for Vertex AI 
pass-through endpoints +# class VertexPassThroughRouter: +# """ +# Vertex Pass Through Router for Vertex AI pass-through endpoints - - if request specifies a project-id, location -> use credentials corresponding to the project-id, location - - if request does not specify a project-id, location -> use credentials corresponding to the DEFAULT_VERTEXAI_PROJECT, DEFAULT_VERTEXAI_LOCATION - """ +# - if request specifies a project-id, location -> use credentials corresponding to the project-id, location +# - if request does not specify a project-id, location -> use credentials corresponding to the DEFAULT_VERTEXAI_PROJECT, DEFAULT_VERTEXAI_LOCATION +# """ - def __init__(self): - """ - Initialize the VertexPassThroughRouter - Stores the vertex credentials for each deployment key - ``` - { - "project_id-location": VertexPassThroughCredentials, - "adroit-crow-us-central1": VertexPassThroughCredentials, - } - ``` - """ - self.deployment_key_to_vertex_credentials: Dict[ - str, VertexPassThroughCredentials - ] = {} - pass +# def __init__(self): +# """ +# Initialize the VertexPassThroughRouter +# Stores the vertex credentials for each deployment key +# ``` +# { +# "project_id-location": VertexPassThroughCredentials, +# "adroit-crow-us-central1": VertexPassThroughCredentials, +# } +# ``` +# """ +# self.deployment_key_to_vertex_credentials: Dict[ +# str, VertexPassThroughCredentials +# ] = {} +# pass - def get_vertex_credentials( - self, project_id: Optional[str], location: Optional[str] - ) -> VertexPassThroughCredentials: - """ - Get the vertex credentials for the given project-id, location - """ - from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( - default_vertex_config, - ) +# def get_vertex_credentials( +# self, project_id: Optional[str], location: Optional[str] +# ) -> Optional[VertexPassThroughCredentials]: +# """ +# Get the vertex credentials for the given project-id, location +# """ +# from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( +# 
default_vertex_config, +# ) - deployment_key = self._get_deployment_key( - project_id=project_id, - location=location, - ) - if deployment_key is None: - return default_vertex_config - if deployment_key in self.deployment_key_to_vertex_credentials: - return self.deployment_key_to_vertex_credentials[deployment_key] - else: - return default_vertex_config +# deployment_key = self._get_deployment_key( +# project_id=project_id, +# location=location, +# ) +# if deployment_key is None: +# return default_vertex_config +# if deployment_key in self.deployment_key_to_vertex_credentials: +# return self.deployment_key_to_vertex_credentials[deployment_key] +# else: +# return default_vertex_config - def add_vertex_credentials( - self, - project_id: str, - location: str, - vertex_credentials: VERTEX_CREDENTIALS_TYPES, - ): - """ - Add the vertex credentials for the given project-id, location - """ - from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( - _set_default_vertex_config, - ) +# def add_vertex_credentials( +# self, +# project_id: str, +# location: str, +# vertex_credentials: VERTEX_CREDENTIALS_TYPES, +# ): +# """ +# Add the vertex credentials for the given project-id, location +# """ - deployment_key = self._get_deployment_key( - project_id=project_id, - location=location, - ) - if deployment_key is None: - verbose_proxy_logger.debug( - "No deployment key found for project-id, location" - ) - return - vertex_pass_through_credentials = VertexPassThroughCredentials( - vertex_project=project_id, - vertex_location=location, - vertex_credentials=vertex_credentials, - ) - self.deployment_key_to_vertex_credentials[deployment_key] = ( - vertex_pass_through_credentials - ) - verbose_proxy_logger.debug( - f"self.deployment_key_to_vertex_credentials: {json.dumps(self.deployment_key_to_vertex_credentials, indent=4, default=str)}" - ) - _set_default_vertex_config(vertex_pass_through_credentials) +# deployment_key = self._get_deployment_key( +# project_id=project_id, +# 
location=location, +# ) +# if deployment_key is None: +# verbose_proxy_logger.debug( +# "No deployment key found for project-id, location" +# ) +# return +# vertex_pass_through_credentials = VertexPassThroughCredentials( +# vertex_project=project_id, +# vertex_location=location, +# vertex_credentials=vertex_credentials, +# ) +# self.deployment_key_to_vertex_credentials[deployment_key] = ( +# vertex_pass_through_credentials +# ) +# verbose_proxy_logger.debug( +# f"self.deployment_key_to_vertex_credentials: {json.dumps(self.deployment_key_to_vertex_credentials, indent=4, default=str)}" +# ) - def _get_deployment_key( - self, project_id: Optional[str], location: Optional[str] - ) -> Optional[str]: - """ - Get the deployment key for the given project-id, location - """ - if project_id is None or location is None: - return None - return f"{project_id}-{location}" - @staticmethod - def _get_vertex_project_id_from_url(url: str) -> Optional[str]: - """ - Get the vertex project id from the url +# @staticmethod +# def _get_vertex_project_id_from_url(url: str) -> Optional[str]: +# """ +# Get the vertex project id from the url - `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent` - """ - match = re.search(r"/projects/([^/]+)", url) - return match.group(1) if match else None +# `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent` +# """ +# match = re.search(r"/projects/([^/]+)", url) +# return match.group(1) if match else None - @staticmethod - def _get_vertex_location_from_url(url: str) -> Optional[str]: - """ - Get the vertex location from the url +# @staticmethod +# def _get_vertex_location_from_url(url: str) -> Optional[str]: +# """ +# Get the vertex location from the url - 
`https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent` - """ - match = re.search(r"/locations/([^/]+)", url) - return match.group(1) if match else None +# `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent` +# """ +# match = re.search(r"/locations/([^/]+)", url) +# return match.group(1) if match else None diff --git a/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py b/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py index 2f5ce85de7..48cb60968b 100644 --- a/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py +++ b/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py @@ -1,6 +1,7 @@ import json import os import sys +from unittest import mock from unittest.mock import MagicMock, patch import httpx @@ -17,7 +18,9 @@ from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( BaseOpenAIPassThroughHandler, RouteChecks, create_pass_through_route, + vertex_proxy_route, ) +from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials class TestBaseOpenAIPassThroughHandler: @@ -176,3 +179,72 @@ class TestBaseOpenAIPassThroughHandler: print(f"query_params: {call_kwargs['query_params']}") assert call_kwargs["stream"] is False assert call_kwargs["query_params"] == {"model": "gpt-4"} + + +class TestVertexAIPassThroughHandler: + """ + Case 1: User set passthrough credentials - confirm credentials used. + + Case 2: User set default credentials, no exact passthrough credentials - confirm default credentials used. + + Case 3: No default credentials, incorrect project/base passed - confirm no credentials used. 
+ """ + + @pytest.mark.asyncio + async def test_vertex_passthrough_with_credentials(self): + """ + Test that when passthrough credentials are set, they are correctly used in the request + """ + # Mock request + mock_request = Request( + scope={ + "type": "http", + "method": "POST", + "path": "/vertex_ai/models/test-model/predict", + "headers": {}, + } + ) + + # Mock response + mock_response = Response() + + # Mock vertex credentials + test_project = "test-project" + test_location = "us-central1" + test_token = "test-token-123" + + with mock.patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router.get_vertex_credentials" + ) as mock_get_creds, mock.patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async" + ) as mock_ensure_token, mock.patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url" + ) as mock_get_token, mock.patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route" + ) as mock_create_route: + + # Setup mock returns + mock_get_creds.return_value = VertexPassThroughCredentials( + vertex_project=test_project, + vertex_location=test_location, + vertex_credentials="test-creds", + ) + mock_ensure_token.return_value = ("test-auth-header", test_project) + mock_get_token.return_value = (test_token, "") + + # Call the route + try: + await vertex_proxy_route( + endpoint="models/test-model/predict", + request=mock_request, + fastapi_response=mock_response, + ) + except Exception as e: + print(f"Error: {e}") + + # Verify create_pass_through_route was called with correct arguments + mock_create_route.assert_called_once_with( + endpoint="models/test-model/predict", + target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/models/test-model/predict", + custom_headers={"Authorization": f"Bearer {test_token}"}, + ) diff 
--git a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py index ba5dfa33a8..9b354a84c9 100644 --- a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py +++ b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py @@ -218,6 +218,29 @@ async def test_get_vertex_credentials_stored(): assert creds.vertex_credentials == '{"credentials": "test-creds"}' +@pytest.mark.asyncio +async def test_default_credentials(): + """ + Test get_vertex_credentials fallback to default credentials. + + Tests that default credentials are used if set. + + Tests that no credentials are returned if no default is set. + """ + router = VertexPassThroughRouter() + router.add_vertex_credentials( + project_id="test-project", + location="us-central1", + vertex_credentials='{"credentials": "test-creds"}', + ) + + creds = router.get_vertex_credentials( + project_id="test-project", location="us-central2" + ) + + assert creds is None + + @pytest.mark.asyncio async def test_add_vertex_credentials(): """Test add_vertex_credentials functionality"""