refactor(llm_passthrough_endpoints.py): refactor vertex passthrough to use common llm passthrough handler.py

Krrish Dholakia 2025-03-22 10:42:46 -07:00
parent 6bc6859224
commit 94d3413335
8 changed files with 650 additions and 338 deletions

litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py

@@ -12,10 +12,13 @@ import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.constants import BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.proxy._types import *
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.common_utils import get_litellm_virtual_key
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
    create_pass_through_route,
)
@@ -23,6 +26,7 @@ from litellm.secret_managers.main import get_secret_str
from .passthrough_endpoint_router import PassthroughEndpointRouter
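# Module-level VertexBase instance; the vertex pass-through route below uses it to mint Google access tokens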
vertex_llm_base = VertexBase()
router = APIRouter()
default_vertex_config = None
@@ -417,6 +421,135 @@ async def azure_proxy_route(
)
@router.api_route(
    "/vertex-ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
    include_in_schema=False,
)
@router.api_route(
    "/vertex_ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
)
async def vertex_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
"""
Call LiteLLM proxy via Vertex AI SDK.
[Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
"""
from litellm.llms.vertex_ai.common_utils import (
construct_target_url,
get_vertex_location_from_url,
get_vertex_project_id_from_url,
)
encoded_endpoint = httpx.URL(endpoint).path
verbose_proxy_logger.debug("requested endpoint %s", endpoint)
headers: dict = {}
api_key_to_use = get_litellm_virtual_key(request=request)
user_api_key_dict = await user_api_key_auth(
request=request,
api_key=api_key_to_use,
)
vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint)
vertex_location: Optional[str] = get_vertex_location_from_url(endpoint)
vertex_credentials = passthrough_endpoint_router.get_vertex_credentials(
project_id=vertex_project,
location=vertex_location,
)
if vertex_credentials is None:
raise Exception(
f"No matching vertex credentials found, for project_id: {vertex_project}, location: {vertex_location}. No default_vertex_config set either."
)
    # If the matched credentials carry no vertex project (no default_vertex_config),
    # fall back to forwarding the incoming request's own headers
    if vertex_credentials.vertex_project is None:
        headers = dict(request.headers) or {}
        verbose_proxy_logger.debug(
            "default_vertex_config not set, incoming request headers %s", headers
        )
        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
        headers.pop("content-length", None)
        headers.pop("host", None)
    else:
        vertex_project = vertex_credentials.vertex_project
        vertex_location = vertex_credentials.vertex_location
        vertex_credentials_str = vertex_credentials.vertex_credentials

        # Construct base URL for the target endpoint
        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"

        _auth_header, vertex_project = await vertex_llm_base._ensure_access_token_async(
            credentials=vertex_credentials_str,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai_beta",
        )
        auth_header, _ = vertex_llm_base._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials_str,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base="",
        )
        headers = {
            "Authorization": f"Bearer {auth_header}",
        }
    request_route = encoded_endpoint
    verbose_proxy_logger.debug("request_route %s", request_route)

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    updated_url = construct_target_url(
        base_url=base_target_url,
        requested_route=encoded_endpoint,
        default_vertex_location=vertex_location,
        default_vertex_project=vertex_project,
    )
    # base_url = httpx.URL(base_target_url)
    # updated_url = base_url.copy_with(path=encoded_endpoint)
    verbose_proxy_logger.debug("updated url %s", updated_url)
    ## check for streaming
    target = str(updated_url)
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True
        target += "?alt=sse"  # Vertex streaming endpoints emit server-sent events when alt=sse is set
    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=target,
        custom_headers=headers,
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )
    return received_value
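For reference, a minimal client-side sketch of calling this pass-through route. The proxy URL, virtual key, and model name below are placeholder assumptions rather than values from this commit; the request body follows the standard Vertex AI generateContent schema.

import httpx

PROXY_BASE_URL = "http://localhost:4000"  # assumed local LiteLLM proxy
VIRTUAL_KEY = "sk-1234"                   # assumed LiteLLM virtual key

# POST to the /vertex_ai/{endpoint:path} route defined above; the proxy
# resolves Vertex credentials and forwards to *-aiplatform.googleapis.com.
response = httpx.post(
    f"{PROXY_BASE_URL}/vertex_ai/publishers/google/models/gemini-1.5-pro:generateContent",
    headers={"x-litellm-api-key": VIRTUAL_KEY},
    json={"contents": [{"role": "user", "parts": [{"text": "hello"}]}]},
)
response.raise_for_status()
print(response.json())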
@router.api_route(
    "/openai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],