(Feat) pass through vertex - allow using credentials defined on litellm router for vertex pass through (#8100)

* test_add_vertex_pass_through_deployment

* VertexPassThroughRouter

* fix use_in_pass_through

* VertexPassThroughRouter

* fix vertex_credentials

* allow using _initialize_deployment_for_pass_through

* test_add_vertex_pass_through_deployment

* _set_default_vertex_config

* fix verbose_proxy_logger

* fix use_in_pass_through

* fix _get_token_and_url

* test_get_vertex_location_from_url

* test_get_vertex_credentials_none

* run pt unit testing again

* fix add_vertex_credentials

* test_adding_deployments.py

* rename file
This commit is contained in:
Ishaan Jaff 2025-01-29 17:54:02 -08:00 committed by GitHub
parent 892581ffc3
commit b6d61ec22b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 490 additions and 19 deletions

View file

@@ -15,7 +15,10 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
from litellm.secret_managers.main import get_secret_str
from litellm.types.passthrough_endpoints.vertex_ai import *
from .vertex_passthrough_router import VertexPassThroughRouter
router = APIRouter()
vertex_pass_through_router = VertexPassThroughRouter()
default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials()
@@ -59,7 +62,14 @@ def set_default_vertex_config(config: Optional[dict] = None):
if isinstance(value, str) and value.startswith("os.environ/"):
config[key] = litellm.get_secret(value)
default_vertex_config = VertexPassThroughCredentials(**config)
_set_default_vertex_config(VertexPassThroughCredentials(**config))
def _set_default_vertex_config(
    vertex_pass_through_credentials: VertexPassThroughCredentials,
):
    """Rebind the module-level ``default_vertex_config``.

    Args:
        vertex_pass_through_credentials: The credentials object that
            becomes the new value of ``default_vertex_config``.
    """
    # Without the ``global`` declaration the assignment below would only
    # create a function-local name and leave the module value untouched.
    global default_vertex_config
    default_vertex_config = vertex_pass_through_credentials
def exception_handler(e: Exception):
@@ -147,9 +157,6 @@ async def vertex_proxy_route(
[Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
"""
encoded_endpoint = httpx.URL(endpoint).path
import re
verbose_proxy_logger.debug("requested endpoint %s", endpoint)
headers: dict = {}
api_key_to_use = get_litellm_virtual_key(request=request)
@@ -158,31 +165,37 @@ async def vertex_proxy_route(
api_key=api_key_to_use,
)
vertex_project = None
vertex_location = None
# Use headers from the incoming request if default_vertex_config is not set
if default_vertex_config.vertex_project is None:
vertex_project: Optional[str] = (
VertexPassThroughRouter._get_vertex_project_id_from_url(endpoint)
)
vertex_location: Optional[str] = (
VertexPassThroughRouter._get_vertex_location_from_url(endpoint)
)
vertex_credentials = vertex_pass_through_router.get_vertex_credentials(
project_id=vertex_project,
location=vertex_location,
)
# Use headers from the incoming request if no vertex credentials are found
if vertex_credentials.vertex_project is None:
headers = dict(request.headers) or {}
verbose_proxy_logger.debug(
"default_vertex_config not set, incoming request headers %s", headers
)
# extract location from endpoint, endpoint
# "v1beta1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent"
match = re.search(r"/locations/([^/]+)", endpoint)
vertex_location = match.group(1) if match else None
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
headers.pop("content-length", None)
headers.pop("host", None)
else:
vertex_project = default_vertex_config.vertex_project
vertex_location = default_vertex_config.vertex_location
vertex_credentials = default_vertex_config.vertex_credentials
vertex_project = vertex_credentials.vertex_project
vertex_location = vertex_credentials.vertex_location
vertex_credentials_str = vertex_credentials.vertex_credentials
# Construct base URL for the target endpoint
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
_auth_header, vertex_project = (
await vertex_fine_tuning_apis_instance._ensure_access_token_async(
credentials=vertex_credentials,
credentials=vertex_credentials_str,
project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
)
@@ -192,7 +205,7 @@ async def vertex_proxy_route(
model="",
auth_header=_auth_header,
gemini_api_key=None,
vertex_credentials=vertex_credentials,
vertex_credentials=vertex_credentials_str,
vertex_project=vertex_project,
vertex_location=vertex_location,
stream=False,