(Feat) pass through vertex - allow using credentials defined on litellm router for vertex pass through (#8100)

* test_add_vertex_pass_through_deployment

* VertexPassThroughRouter

* fix use_in_pass_through

* VertexPassThroughRouter

* fix vertex_credentials

* allow using _initialize_deployment_for_pass_through

* test_add_vertex_pass_through_deployment

* _set_default_vertex_config

* fix verbose_proxy_logger

* fix use_in_pass_through

* fix _get_token_and_url

* test_get_vertex_location_from_url

* test_get_vertex_credentials_none

* run pt unit testing again

* fix add_vertex_credentials

* test_adding_deployments.py

* rename file
This commit is contained in:
Ishaan Jaff 2025-01-29 17:54:02 -08:00 committed by GitHub
parent 892581ffc3
commit b6d61ec22b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 490 additions and 19 deletions

View file

@ -15,7 +15,10 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
from litellm.secret_managers.main import get_secret_str from litellm.secret_managers.main import get_secret_str
from litellm.types.passthrough_endpoints.vertex_ai import * from litellm.types.passthrough_endpoints.vertex_ai import *
from .vertex_passthrough_router import VertexPassThroughRouter
router = APIRouter() router = APIRouter()
vertex_pass_through_router = VertexPassThroughRouter()
default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials() default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials()
@ -59,7 +62,14 @@ def set_default_vertex_config(config: Optional[dict] = None):
if isinstance(value, str) and value.startswith("os.environ/"): if isinstance(value, str) and value.startswith("os.environ/"):
config[key] = litellm.get_secret(value) config[key] = litellm.get_secret(value)
default_vertex_config = VertexPassThroughCredentials(**config) _set_default_vertex_config(VertexPassThroughCredentials(**config))
def _set_default_vertex_config(
    vertex_pass_through_credentials: VertexPassThroughCredentials,
):
    """
    Overwrite the module-level `default_vertex_config` — the fallback
    credential set used when a pass-through request carries no
    project-id/location-specific credentials.
    """
    global default_vertex_config
    default_vertex_config = vertex_pass_through_credentials
def exception_handler(e: Exception): def exception_handler(e: Exception):
@ -147,9 +157,6 @@ async def vertex_proxy_route(
[Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai) [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
""" """
encoded_endpoint = httpx.URL(endpoint).path encoded_endpoint = httpx.URL(endpoint).path
import re
verbose_proxy_logger.debug("requested endpoint %s", endpoint) verbose_proxy_logger.debug("requested endpoint %s", endpoint)
headers: dict = {} headers: dict = {}
api_key_to_use = get_litellm_virtual_key(request=request) api_key_to_use = get_litellm_virtual_key(request=request)
@ -158,31 +165,37 @@ async def vertex_proxy_route(
api_key=api_key_to_use, api_key=api_key_to_use,
) )
vertex_project = None vertex_project: Optional[str] = (
vertex_location = None VertexPassThroughRouter._get_vertex_project_id_from_url(endpoint)
# Use headers from the incoming request if default_vertex_config is not set )
if default_vertex_config.vertex_project is None: vertex_location: Optional[str] = (
VertexPassThroughRouter._get_vertex_location_from_url(endpoint)
)
vertex_credentials = vertex_pass_through_router.get_vertex_credentials(
project_id=vertex_project,
location=vertex_location,
)
# Use headers from the incoming request if no vertex credentials are found
if vertex_credentials.vertex_project is None:
headers = dict(request.headers) or {} headers = dict(request.headers) or {}
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"default_vertex_config not set, incoming request headers %s", headers "default_vertex_config not set, incoming request headers %s", headers
) )
# extract location from endpoint, endpoint
# "v1beta1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent"
match = re.search(r"/locations/([^/]+)", endpoint)
vertex_location = match.group(1) if match else None
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
headers.pop("content-length", None) headers.pop("content-length", None)
headers.pop("host", None) headers.pop("host", None)
else: else:
vertex_project = default_vertex_config.vertex_project vertex_project = vertex_credentials.vertex_project
vertex_location = default_vertex_config.vertex_location vertex_location = vertex_credentials.vertex_location
vertex_credentials = default_vertex_config.vertex_credentials vertex_credentials_str = vertex_credentials.vertex_credentials
# Construct base URL for the target endpoint
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
_auth_header, vertex_project = ( _auth_header, vertex_project = (
await vertex_fine_tuning_apis_instance._ensure_access_token_async( await vertex_fine_tuning_apis_instance._ensure_access_token_async(
credentials=vertex_credentials, credentials=vertex_credentials_str,
project_id=vertex_project, project_id=vertex_project,
custom_llm_provider="vertex_ai_beta", custom_llm_provider="vertex_ai_beta",
) )
@ -192,7 +205,7 @@ async def vertex_proxy_route(
model="", model="",
auth_header=_auth_header, auth_header=_auth_header,
gemini_api_key=None, gemini_api_key=None,
vertex_credentials=vertex_credentials, vertex_credentials=vertex_credentials_str,
vertex_project=vertex_project, vertex_project=vertex_project,
vertex_location=vertex_location, vertex_location=vertex_location,
stream=False, stream=False,

View file

@ -0,0 +1,120 @@
import json
import re
from typing import Dict, Optional
from litellm._logging import verbose_proxy_logger
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
VertexPassThroughCredentials,
)
class VertexPassThroughRouter:
"""
Vertex Pass Through Router for Vertex AI pass-through endpoints
- if request specifies a project-id, location -> use credentials corresponding to the project-id, location
- if request does not specify a project-id, location -> use credentials corresponding to the DEFAULT_VERTEXAI_PROJECT, DEFAULT_VERTEXAI_LOCATION
"""
def __init__(self):
"""
Initialize the VertexPassThroughRouter
Stores the vertex credentials for each deployment key
```
{
"project_id-location": VertexPassThroughCredentials,
"adroit-crow-us-central1": VertexPassThroughCredentials,
}
```
"""
self.deployment_key_to_vertex_credentials: Dict[
str, VertexPassThroughCredentials
] = {}
pass
def get_vertex_credentials(
self, project_id: Optional[str], location: Optional[str]
) -> VertexPassThroughCredentials:
"""
Get the vertex credentials for the given project-id, location
"""
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
default_vertex_config,
)
deployment_key = self._get_deployment_key(
project_id=project_id,
location=location,
)
if deployment_key is None:
return default_vertex_config
if deployment_key in self.deployment_key_to_vertex_credentials:
return self.deployment_key_to_vertex_credentials[deployment_key]
else:
return default_vertex_config
def add_vertex_credentials(
self,
project_id: str,
location: str,
vertex_credentials: str,
):
"""
Add the vertex credentials for the given project-id, location
"""
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
_set_default_vertex_config,
)
deployment_key = self._get_deployment_key(
project_id=project_id,
location=location,
)
if deployment_key is None:
verbose_proxy_logger.debug(
"No deployment key found for project-id, location"
)
return
vertex_pass_through_credentials = VertexPassThroughCredentials(
vertex_project=project_id,
vertex_location=location,
vertex_credentials=vertex_credentials,
)
self.deployment_key_to_vertex_credentials[deployment_key] = (
vertex_pass_through_credentials
)
verbose_proxy_logger.debug(
f"self.deployment_key_to_vertex_credentials: {json.dumps(self.deployment_key_to_vertex_credentials, indent=4, default=str)}"
)
_set_default_vertex_config(vertex_pass_through_credentials)
def _get_deployment_key(
self, project_id: Optional[str], location: Optional[str]
) -> Optional[str]:
"""
Get the deployment key for the given project-id, location
"""
if project_id is None or location is None:
return None
return f"{project_id}-{location}"
@staticmethod
def _get_vertex_project_id_from_url(url: str) -> Optional[str]:
"""
Get the vertex project id from the url
`https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`
"""
match = re.search(r"/projects/([^/]+)", url)
return match.group(1) if match else None
@staticmethod
def _get_vertex_location_from_url(url: str) -> Optional[str]:
"""
Get the vertex location from the url
`https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`
"""
match = re.search(r"/locations/([^/]+)", url)
return match.group(1) if match else None

View file

@ -4133,8 +4133,48 @@ class Router:
litellm_router_instance=self, model=deployment.to_json(exclude_none=True) litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
) )
self._initialize_deployment_for_pass_through(
deployment=deployment,
custom_llm_provider=custom_llm_provider,
model=deployment.litellm_params.model,
)
return deployment return deployment
def _initialize_deployment_for_pass_through(
self, deployment: Deployment, custom_llm_provider: str, model: str
):
"""
Optional: Initialize deployment for pass-through endpoints if `deployment.litellm_params.use_in_pass_through` is True
Each provider uses diff .env vars for pass-through endpoints, this helper uses the deployment credentials to set the .env vars for pass-through endpoints
"""
if deployment.litellm_params.use_in_pass_through is True:
if custom_llm_provider == "vertex_ai":
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
vertex_pass_through_router,
)
if (
deployment.litellm_params.vertex_project is None
or deployment.litellm_params.vertex_location is None
or deployment.litellm_params.vertex_credentials is None
):
raise ValueError(
"vertex_project, vertex_location, and vertex_credentials must be set in litellm_params for pass-through endpoints"
)
vertex_pass_through_router.add_vertex_credentials(
project_id=deployment.litellm_params.vertex_project,
location=deployment.litellm_params.vertex_location,
vertex_credentials=deployment.litellm_params.vertex_credentials,
)
else:
verbose_router_logger.error(
f"Unsupported provider - {custom_llm_provider} for pass-through endpoints"
)
pass
pass
def add_deployment(self, deployment: Deployment) -> Optional[Deployment]: def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
""" """
Parameters: Parameters:

View file

@ -176,7 +176,7 @@ class GenericLiteLLMParams(BaseModel):
# Deployment budgets # Deployment budgets
max_budget: Optional[float] = None max_budget: Optional[float] = None
budget_duration: Optional[str] = None budget_duration: Optional[str] = None
use_in_pass_through: Optional[bool] = False
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
def __init__( def __init__(
@ -215,6 +215,8 @@ class GenericLiteLLMParams(BaseModel):
# Deployment budgets # Deployment budgets
max_budget: Optional[float] = None, max_budget: Optional[float] = None,
budget_duration: Optional[str] = None, budget_duration: Optional[str] = None,
# Pass through params
use_in_pass_through: Optional[bool] = False,
**params, **params,
): ):
args = locals() args = locals()
@ -276,6 +278,8 @@ class LiteLLM_Params(GenericLiteLLMParams):
# OpenAI / Azure Whisper # OpenAI / Azure Whisper
# set a max-size of file that can be passed to litellm proxy # set a max-size of file that can be passed to litellm proxy
max_file_size_mb: Optional[float] = None, max_file_size_mb: Optional[float] = None,
# will use deployment on pass-through endpoints if True
use_in_pass_through: Optional[bool] = False,
**params, **params,
): ):
args = locals() args = locals()

View file

@ -1729,6 +1729,7 @@ all_litellm_params = [
"max_fallbacks", "max_fallbacks",
"max_budget", "max_budget",
"budget_duration", "budget_duration",
"use_in_pass_through",
] + list(StandardCallbackDynamicParams.__annotations__.keys()) ] + list(StandardCallbackDynamicParams.__annotations__.keys())

View file

@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, Mock, patch
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system-path
import httpx import httpx
@ -23,6 +23,9 @@ from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
VertexPassThroughCredentials, VertexPassThroughCredentials,
default_vertex_config, default_vertex_config,
) )
from litellm.proxy.vertex_ai_endpoints.vertex_passthrough_router import (
VertexPassThroughRouter,
)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -167,3 +170,123 @@ async def test_set_default_vertex_config():
del os.environ["DEFAULT_VERTEXAI_LOCATION"] del os.environ["DEFAULT_VERTEXAI_LOCATION"]
del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
del os.environ["GOOGLE_CREDS"] del os.environ["GOOGLE_CREDS"]
@pytest.mark.asyncio
async def test_vertex_passthrough_router_init():
    """Test VertexPassThroughRouter initialization"""
    router = VertexPassThroughRouter()
    # a fresh router starts with an empty credential store
    assert isinstance(router.deployment_key_to_vertex_credentials, dict)
    assert len(router.deployment_key_to_vertex_credentials) == 0
@pytest.mark.asyncio
async def test_get_vertex_credentials_none():
    """Test get_vertex_credentials with various inputs"""
    from litellm.proxy.vertex_ai_endpoints import vertex_endpoints

    # reset the module-level default so state from earlier tests doesn't leak in
    setattr(vertex_endpoints, "default_vertex_config", VertexPassThroughCredentials())
    router = VertexPassThroughRouter()

    # Test with None project_id and location - should return default config
    creds = router.get_vertex_credentials(None, None)
    assert isinstance(creds, VertexPassThroughCredentials)

    # Test with valid project_id and location but no stored credentials
    creds = router.get_vertex_credentials("test-project", "us-central1")
    assert isinstance(creds, VertexPassThroughCredentials)
    # no stored entry -> falls back to the (empty) default config
    assert creds.vertex_project is None
    assert creds.vertex_location is None
    assert creds.vertex_credentials is None
@pytest.mark.asyncio
async def test_get_vertex_credentials_stored():
    """Test get_vertex_credentials with stored credentials"""
    router = VertexPassThroughRouter()
    router.add_vertex_credentials(
        project_id="test-project",
        location="us-central1",
        vertex_credentials="test-creds",
    )

    creds = router.get_vertex_credentials(
        project_id="test-project", location="us-central1"
    )
    # stored credentials are returned verbatim for the matching key
    assert creds.vertex_project == "test-project"
    assert creds.vertex_location == "us-central1"
    assert creds.vertex_credentials == "test-creds"
@pytest.mark.asyncio
async def test_add_vertex_credentials():
    """Test add_vertex_credentials functionality"""
    router = VertexPassThroughRouter()

    # Test adding valid credentials
    router.add_vertex_credentials(
        project_id="test-project",
        location="us-central1",
        vertex_credentials="test-creds",
    )
    # store key format is "{project_id}-{location}"
    assert "test-project-us-central1" in router.deployment_key_to_vertex_credentials
    creds = router.deployment_key_to_vertex_credentials["test-project-us-central1"]
    assert creds.vertex_project == "test-project"
    assert creds.vertex_location == "us-central1"
    assert creds.vertex_credentials == "test-creds"

    # Test adding with None values
    router.add_vertex_credentials(
        project_id=None, location=None, vertex_credentials="test-creds"
    )
    # Should not add None values
    assert len(router.deployment_key_to_vertex_credentials) == 1
@pytest.mark.asyncio
async def test_get_deployment_key():
    """Test _get_deployment_key with various inputs"""
    router = VertexPassThroughRouter()

    # Test with valid inputs
    key = router._get_deployment_key("test-project", "us-central1")
    assert key == "test-project-us-central1"

    # Test with None values - any missing part must yield no key
    key = router._get_deployment_key(None, "us-central1")
    assert key is None
    key = router._get_deployment_key("test-project", None)
    assert key is None
    key = router._get_deployment_key(None, None)
    assert key is None
@pytest.mark.asyncio
async def test_get_vertex_project_id_from_url():
    """Test _get_vertex_project_id_from_url with various URLs"""
    # Test with valid URL
    url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    project_id = VertexPassThroughRouter._get_vertex_project_id_from_url(url)
    assert project_id == "test-project"

    # Test with invalid URL - no /projects/ segment present
    url = "https://invalid-url.com"
    project_id = VertexPassThroughRouter._get_vertex_project_id_from_url(url)
    assert project_id is None
@pytest.mark.asyncio
async def test_get_vertex_location_from_url():
    """Test _get_vertex_location_from_url with various URLs"""
    # Test with valid URL
    url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    location = VertexPassThroughRouter._get_vertex_location_from_url(url)
    assert location == "us-central1"

    # Test with invalid URL - no /locations/ segment present
    url = "https://invalid-url.com"
    location = VertexPassThroughRouter._get_vertex_location_from_url(url)
    assert location is None

View file

@ -0,0 +1,170 @@
import sys, os
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm import Router
from litellm.router import Deployment, LiteLLM_Params
from unittest.mock import patch
import json
def test_initialize_deployment_for_pass_through_success():
    """
    Test successful initialization of a Vertex AI pass-through deployment
    """
    router = Router(model_list=[])
    deployment = Deployment(
        model_name="vertex-test",
        litellm_params=LiteLLM_Params(
            model="vertex_ai/test-model",
            vertex_project="test-project",
            vertex_location="us-central1",
            vertex_credentials=json.dumps(
                {"type": "service_account", "project_id": "test"}
            ),
            use_in_pass_through=True,
        ),
    )

    # Test the initialization
    router._initialize_deployment_for_pass_through(
        deployment=deployment,
        custom_llm_provider="vertex_ai",
        model="vertex_ai/test-model",
    )

    # Verify the credentials were properly set
    # NOTE(review): asserts against the module-level singleton router, so
    # state can leak between tests in this file - confirm ordering is safe
    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
        vertex_pass_through_router,
    )

    vertex_creds = vertex_pass_through_router.get_vertex_credentials(
        project_id="test-project", location="us-central1"
    )
    assert vertex_creds.vertex_project == "test-project"
    assert vertex_creds.vertex_location == "us-central1"
    assert vertex_creds.vertex_credentials == json.dumps(
        {"type": "service_account", "project_id": "test"}
    )
def test_initialize_deployment_for_pass_through_missing_params():
    """
    Test initialization fails when required Vertex AI parameters are missing
    """
    router = Router(model_list=[])
    deployment = Deployment(
        model_name="vertex-test",
        litellm_params=LiteLLM_Params(
            model="vertex_ai/test-model",
            # Missing required parameters (vertex_project/location/credentials)
            use_in_pass_through=True,
        ),
    )

    # Test that initialization raises ValueError
    with pytest.raises(
        ValueError,
        match="vertex_project, vertex_location, and vertex_credentials must be set",
    ):
        router._initialize_deployment_for_pass_through(
            deployment=deployment,
            custom_llm_provider="vertex_ai",
            model="vertex_ai/test-model",
        )
def test_initialize_deployment_for_pass_through_unsupported_provider():
    """
    Test initialization with an unsupported provider
    """
    router = Router(model_list=[])
    deployment = Deployment(
        model_name="unsupported-test",
        litellm_params=LiteLLM_Params(
            model="unsupported/test-model",
            use_in_pass_through=True,
        ),
    )

    # Should not raise an error, but log a warning
    # (unsupported providers are skipped, not treated as fatal)
    router._initialize_deployment_for_pass_through(
        deployment=deployment,
        custom_llm_provider="unsupported_provider",
        model="unsupported/test-model",
    )
def test_initialize_deployment_when_pass_through_disabled():
    """
    Test that initialization simply exits when use_in_pass_through is False
    """
    router = Router(model_list=[])
    # use_in_pass_through defaults to False, so no vertex params are needed
    deployment = Deployment(
        model_name="vertex-test",
        litellm_params=LiteLLM_Params(
            model="vertex_ai/test-model",
        ),
    )

    # This should exit without error, even with missing vertex parameters
    router._initialize_deployment_for_pass_through(
        deployment=deployment,
        custom_llm_provider="vertex_ai",
        model="vertex_ai/test-model",
    )

    # If we reach this point, the test passes as the method exited without raising any errors
    assert True
def test_add_vertex_pass_through_deployment():
    """
    Test adding a Vertex AI deployment with pass-through configuration
    """
    router = Router(model_list=[])

    # Create a deployment with Vertex AI pass-through settings
    deployment = Deployment(
        model_name="vertex-test",
        litellm_params=LiteLLM_Params(
            model="vertex_ai/test-model",
            vertex_project="test-project",
            vertex_location="us-central1",
            vertex_credentials=json.dumps(
                {"type": "service_account", "project_id": "test"}
            ),
            use_in_pass_through=True,
        ),
    )

    # Add deployment to router; this should forward credentials to the
    # module-level vertex_pass_through_router singleton
    router.add_deployment(deployment)

    # Get the vertex credentials from the router
    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
        vertex_pass_through_router,
    )

    # current state of pass-through vertex router (debug output only)
    print("\n vertex_pass_through_router.deployment_key_to_vertex_credentials\n\n")
    print(
        json.dumps(
            vertex_pass_through_router.deployment_key_to_vertex_credentials,
            indent=4,
            default=str,
        )
    )

    vertex_creds = vertex_pass_through_router.get_vertex_credentials(
        project_id="test-project", location="us-central1"
    )

    # Verify the credentials were properly set
    assert vertex_creds.vertex_project == "test-project"
    assert vertex_creds.vertex_location == "us-central1"
    assert vertex_creds.vertex_credentials == json.dumps(
        {"type": "service_account", "project_id": "test"}
    )