refactor(llm_passthrough_endpoints.py): use common _base_vertex_proxy_route

Prevents duplicate code
This commit is contained in:
Krrish Dholakia 2025-04-16 18:56:36 -07:00
parent ba38045637
commit 39936c0ce8

View file

@ -578,6 +578,124 @@ async def azure_proxy_route(
)
async def _base_vertex_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
base_target_url: str,
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
):
"""
Base function for Vertex AI passthrough routes.
Handles common logic for all Vertex AI services.
"""
from litellm.llms.vertex_ai.common_utils import (
construct_target_url,
get_vertex_location_from_url,
get_vertex_project_id_from_url,
)
encoded_endpoint = httpx.URL(endpoint).path
verbose_proxy_logger.debug("requested endpoint %s", endpoint)
headers: dict = {}
if user_api_key_dict is None:
api_key_to_use = get_litellm_virtual_key(request=request)
user_api_key_dict = await user_api_key_auth(
request=request,
api_key=api_key_to_use,
)
vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint)
vertex_location: Optional[str] = get_vertex_location_from_url(endpoint)
vertex_credentials = passthrough_endpoint_router.get_vertex_credentials(
project_id=vertex_project,
location=vertex_location,
)
headers_passed_through = False
# Use headers from the incoming request if no vertex credentials are found
if vertex_credentials is None or vertex_credentials.vertex_project is None:
headers = dict(request.headers) or {}
headers_passed_through = True
verbose_proxy_logger.debug(
"default_vertex_config not set, incoming request headers %s", headers
)
headers.pop("content-length", None)
headers.pop("host", None)
else:
vertex_project = vertex_credentials.vertex_project
vertex_location = vertex_credentials.vertex_location
vertex_credentials_str = vertex_credentials.vertex_credentials
_auth_header, vertex_project = await vertex_llm_base._ensure_access_token_async(
credentials=vertex_credentials_str,
project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
)
auth_header, _ = vertex_llm_base._get_token_and_url(
model="",
auth_header=_auth_header,
gemini_api_key=None,
vertex_credentials=vertex_credentials_str,
vertex_project=vertex_project,
vertex_location=vertex_location,
stream=False,
custom_llm_provider="vertex_ai_beta",
api_base="",
)
headers = {
"Authorization": f"Bearer {auth_header}",
}
request_route = encoded_endpoint
verbose_proxy_logger.debug("request_route %s", request_route)
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
updated_url = construct_target_url(
base_url=base_target_url,
requested_route=encoded_endpoint,
vertex_location=vertex_location,
vertex_project=vertex_project,
)
verbose_proxy_logger.debug("updated url %s", updated_url)
## check for streaming
target = str(updated_url)
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
target += "?alt=sse"
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=target,
custom_headers=headers,
) # dynamically construct pass-through endpoint based on incoming path
try:
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request, # type: ignore
)
except ProxyException as e:
if headers_passed_through:
e.message = f"No credentials found on proxy for project_name={vertex_project} + location={vertex_location}, check `/model/info` for allowed project + region combinations with `use_in_pass_through: true`. Headers were passed through directly but request failed with error: {e.message}"
raise e
return received_value
@router.api_route(
"/vertex_ai/discovery/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
@ -595,112 +713,12 @@ async def vertex_discovery_proxy_route(
Target url: `https://discoveryengine.googleapis.com`
"""
from litellm.llms.vertex_ai.common_utils import (
construct_target_url,
get_vertex_location_from_url,
get_vertex_project_id_from_url,
)
encoded_endpoint = httpx.URL(endpoint).path
verbose_proxy_logger.debug("requested endpoint %s", endpoint)
headers: dict = {}
api_key_to_use = get_litellm_virtual_key(request=request)
user_api_key_dict = await user_api_key_auth(
request=request,
api_key=api_key_to_use,
)
vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint)
vertex_location: Optional[str] = get_vertex_location_from_url(endpoint)
vertex_credentials = passthrough_endpoint_router.get_vertex_credentials(
project_id=vertex_project,
location=vertex_location,
)
headers_passed_through = False
# Use headers from the incoming request if no vertex credentials are found
if vertex_credentials is None or vertex_credentials.vertex_project is None:
headers = dict(request.headers) or {}
headers_passed_through = True
verbose_proxy_logger.debug(
"default_vertex_config not set, incoming request headers %s", headers
)
base_target_url = "https://discoveryengine.googleapis.com/"
headers.pop("content-length", None)
headers.pop("host", None)
else:
vertex_project = vertex_credentials.vertex_project
vertex_location = vertex_credentials.vertex_location
vertex_credentials_str = vertex_credentials.vertex_credentials
# Construct base URL for the target endpoint
base_target_url = "https://discoveryengine.googleapis.com/"
_auth_header, vertex_project = await vertex_llm_base._ensure_access_token_async(
credentials=vertex_credentials_str,
project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
)
auth_header, _ = vertex_llm_base._get_token_and_url(
model="",
auth_header=_auth_header,
gemini_api_key=None,
vertex_credentials=vertex_credentials_str,
vertex_project=vertex_project,
vertex_location=vertex_location,
stream=False,
custom_llm_provider="vertex_ai_beta",
api_base="",
)
headers = {
"Authorization": f"Bearer {auth_header}",
}
request_route = encoded_endpoint
verbose_proxy_logger.debug("request_route %s", request_route)
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
updated_url = construct_target_url(
base_url=base_target_url,
requested_route=encoded_endpoint,
vertex_location=vertex_location,
vertex_project=vertex_project,
)
verbose_proxy_logger.debug("updated url %s", updated_url)
## check for streaming
target = str(updated_url)
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
target += "?alt=sse"
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
return await _base_vertex_proxy_route(
endpoint=endpoint,
target=target,
custom_headers=headers,
) # dynamically construct pass-through endpoint based on incoming path
try:
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request, # type: ignore
)
except ProxyException as e:
if headers_passed_through:
e.message = f"No credentials found on proxy for project_name={vertex_project} + location={vertex_location}, check `/model/info` for allowed project + region combinations with `use_in_pass_through: true`. Headers were passed through directly but request failed with error: {e.message}"
raise e
return received_value
request=request,
fastapi_response=fastapi_response,
base_target_url="https://discoveryengine.googleapis.com/",
)
@router.api_route(
@ -718,118 +736,25 @@ async def vertex_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Call LiteLLM proxy via Vertex AI SDK.
[Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
"""
from litellm.llms.vertex_ai.common_utils import (
construct_target_url,
get_vertex_location_from_url,
get_vertex_project_id_from_url,
)
from litellm.llms.vertex_ai.common_utils import get_vertex_location_from_url
encoded_endpoint = httpx.URL(endpoint).path
verbose_proxy_logger.debug("requested endpoint %s", endpoint)
headers: dict = {}
api_key_to_use = get_litellm_virtual_key(request=request)
user_api_key_dict = await user_api_key_auth(
request=request,
api_key=api_key_to_use,
)
vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint)
vertex_location: Optional[str] = get_vertex_location_from_url(endpoint)
vertex_credentials = passthrough_endpoint_router.get_vertex_credentials(
project_id=vertex_project,
location=vertex_location,
)
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
headers_passed_through = False
# Use headers from the incoming request if no vertex credentials are found
if vertex_credentials is None or vertex_credentials.vertex_project is None:
headers = dict(request.headers) or {}
headers_passed_through = True
verbose_proxy_logger.debug(
"default_vertex_config not set, incoming request headers %s", headers
)
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
headers.pop("content-length", None)
headers.pop("host", None)
else:
vertex_project = vertex_credentials.vertex_project
vertex_location = vertex_credentials.vertex_location
vertex_credentials_str = vertex_credentials.vertex_credentials
# Construct base URL for the target endpoint
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
_auth_header, vertex_project = await vertex_llm_base._ensure_access_token_async(
credentials=vertex_credentials_str,
project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
)
auth_header, _ = vertex_llm_base._get_token_and_url(
model="",
auth_header=_auth_header,
gemini_api_key=None,
vertex_credentials=vertex_credentials_str,
vertex_project=vertex_project,
vertex_location=vertex_location,
stream=False,
custom_llm_provider="vertex_ai_beta",
api_base="",
)
headers = {
"Authorization": f"Bearer {auth_header}",
}
request_route = encoded_endpoint
verbose_proxy_logger.debug("request_route %s", request_route)
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
updated_url = construct_target_url(
base_url=base_target_url,
requested_route=encoded_endpoint,
vertex_location=vertex_location,
vertex_project=vertex_project,
)
verbose_proxy_logger.debug("updated url %s", updated_url)
## check for streaming
target = str(updated_url)
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
target += "?alt=sse"
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
return await _base_vertex_proxy_route(
endpoint=endpoint,
target=target,
custom_headers=headers,
) # dynamically construct pass-through endpoint based on incoming path
try:
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request, # type: ignore
)
except ProxyException as e:
if headers_passed_through:
e.message = f"No credentials found on proxy for project_name={vertex_project} + location={vertex_location}, check `/model/info` for allowed project + region combinations with `use_in_pass_through: true`. Headers were passed through directly but request failed with error: {e.message}"
raise e
return received_value
request=request,
fastapi_response=fastapi_response,
base_target_url=base_target_url,
user_api_key_dict=user_api_key_dict,
)
@router.api_route(