fix(vertex_ai/common_utils.py): fix handling constructed url with default vertex config

This commit is contained in:
Krrish Dholakia 2025-03-22 11:32:01 -07:00
parent b44b3bd36b
commit 06e69a414e
4 changed files with 98 additions and 17 deletions

View file

@ -303,11 +303,26 @@ def get_vertex_location_from_url(url: str) -> Optional[str]:
return match.group(1) if match else None
def replace_project_and_location_in_route(
requested_route: str, vertex_project: str, vertex_location: str
) -> str:
"""
Replace project and location values in the route with the provided values
"""
# Replace project and location values while keeping route structure
modified_route = re.sub(
r"/projects/[^/]+/locations/[^/]+/",
f"/projects/{vertex_project}/locations/{vertex_location}/",
requested_route,
)
return modified_route
def construct_target_url(
base_url: str,
requested_route: str,
default_vertex_location: Optional[str],
default_vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_project: Optional[str],
) -> httpx.URL:
"""
Allow user to specify their own project id / location.
@ -321,8 +336,12 @@ def construct_target_url(
"""
new_base_url = httpx.URL(base_url)
if "locations" in requested_route: # contains the target project id + location
updated_url = new_base_url.copy_with(path=requested_route)
return updated_url
if vertex_project and vertex_location:
requested_route = replace_project_and_location_in_route(
requested_route, vertex_project, vertex_location
)
return new_base_url.copy_with(path=requested_route)
"""
- Add endpoint version (e.g. v1beta for cachedContent, v1 for rest)
- Add default project id
@ -333,7 +352,7 @@ def construct_target_url(
vertex_version = "v1beta1"
base_requested_route = "{}/projects/{}/locations/{}".format(
vertex_version, default_vertex_project, default_vertex_location
vertex_version, vertex_project, vertex_location
)
updated_requested_route = "/" + base_requested_route + requested_route

View file

@ -463,13 +463,8 @@ async def vertex_proxy_route(
location=vertex_location,
)
if vertex_credentials is None:
raise Exception(
f"No matching vertex credentials found, for project_id: {vertex_project}, location: {vertex_location}. No default_vertex_config set either."
)
# Use headers from the incoming request if no vertex credentials are found
if vertex_credentials.vertex_project is None:
if vertex_credentials is None or vertex_credentials.vertex_project is None:
headers = dict(request.headers) or {}
verbose_proxy_logger.debug(
"default_vertex_config not set, incoming request headers %s", headers
@ -518,8 +513,8 @@ async def vertex_proxy_route(
updated_url = construct_target_url(
base_url=base_target_url,
requested_route=encoded_endpoint,
default_vertex_location=vertex_location,
default_vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_project=vertex_project,
)
# base_url = httpx.URL(base_target_url)
# updated_url = base_url.copy_with(path=encoded_endpoint)

View file

@ -150,19 +150,17 @@ class PassthroughEndpointRouter:
"""
Get the vertex credentials for the given project-id, location
"""
default_vertex_config: Optional[VertexPassThroughCredentials] = None
deployment_key = self._get_deployment_key(
project_id=project_id,
location=location,
)
if deployment_key is None:
return default_vertex_config
return self.default_vertex_config
if deployment_key in self.deployment_key_to_vertex_credentials:
return self.deployment_key_to_vertex_credentials[deployment_key]
else:
return default_vertex_config
return self.default_vertex_config
def _get_credential_name_for_provider(
self,

View file

@ -1,6 +1,7 @@
import json
import os
import sys
import traceback
from unittest import mock
from unittest.mock import MagicMock, patch
@ -262,3 +263,71 @@ class TestVertexAIPassThroughHandler:
target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent",
custom_headers={"Authorization": f"Bearer {test_token}"},
)
@pytest.mark.asyncio
async def test_vertex_passthrough_with_default_credentials(self, monkeypatch):
"""
Test that when no passthrough credentials are set, default credentials are used in the request
"""
from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
PassthroughEndpointRouter,
)
# Setup default credentials
default_project = "default-project"
default_location = "us-central1"
default_credentials = "default-creds"
pass_through_router = PassthroughEndpointRouter()
pass_through_router.default_vertex_config = VertexPassThroughCredentials(
vertex_project=default_project,
vertex_location=default_location,
vertex_credentials=default_credentials,
)
monkeypatch.setattr(
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
pass_through_router,
)
# Use different project/location in request than the default
request_project = "non-existing-project"
request_location = "bad-location"
endpoint = f"/v1/projects/{request_project}/locations/{request_location}/publishers/google/models/gemini-1.5-flash:generateContent"
mock_request = Request(
scope={
"type": "http",
"method": "POST",
"path": endpoint,
"headers": {},
}
)
mock_response = Response()
with mock.patch(
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
) as mock_ensure_token, mock.patch(
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
) as mock_get_token, mock.patch(
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
) as mock_create_route:
mock_ensure_token.return_value = ("test-auth-header", default_project)
mock_get_token.return_value = (default_credentials, "")
try:
await vertex_proxy_route(
endpoint=endpoint,
request=mock_request,
fastapi_response=mock_response,
)
except Exception as e:
traceback.print_exc()
print(f"Error: {e}")
# Verify default credentials were used
mock_create_route.assert_called_once_with(
endpoint=endpoint,
target=f"https://{default_location}-aiplatform.googleapis.com/v1/projects/{default_project}/locations/{default_location}/publishers/google/models/gemini-1.5-flash:generateContent",
custom_headers={"Authorization": f"Bearer {default_credentials}"},
)