fix(vertex_ai/common_utils.py): fix handling constructed url with default vertex config

This commit is contained in:
Krrish Dholakia 2025-03-22 11:32:01 -07:00
parent b44b3bd36b
commit 06e69a414e
4 changed files with 98 additions and 17 deletions

View file

@ -303,11 +303,26 @@ def get_vertex_location_from_url(url: str) -> Optional[str]:
return match.group(1) if match else None return match.group(1) if match else None
def replace_project_and_location_in_route(
requested_route: str, vertex_project: str, vertex_location: str
) -> str:
"""
Replace project and location values in the route with the provided values
"""
# Replace project and location values while keeping route structure
modified_route = re.sub(
r"/projects/[^/]+/locations/[^/]+/",
f"/projects/{vertex_project}/locations/{vertex_location}/",
requested_route,
)
return modified_route
def construct_target_url( def construct_target_url(
base_url: str, base_url: str,
requested_route: str, requested_route: str,
default_vertex_location: Optional[str], vertex_location: Optional[str],
default_vertex_project: Optional[str], vertex_project: Optional[str],
) -> httpx.URL: ) -> httpx.URL:
""" """
Allow user to specify their own project id / location. Allow user to specify their own project id / location.
@ -321,8 +336,12 @@ def construct_target_url(
""" """
new_base_url = httpx.URL(base_url) new_base_url = httpx.URL(base_url)
if "locations" in requested_route: # contains the target project id + location if "locations" in requested_route: # contains the target project id + location
updated_url = new_base_url.copy_with(path=requested_route) if vertex_project and vertex_location:
return updated_url requested_route = replace_project_and_location_in_route(
requested_route, vertex_project, vertex_location
)
return new_base_url.copy_with(path=requested_route)
""" """
- Add endpoint version (e.g. v1beta for cachedContent, v1 for rest) - Add endpoint version (e.g. v1beta for cachedContent, v1 for rest)
- Add default project id - Add default project id
@ -333,7 +352,7 @@ def construct_target_url(
vertex_version = "v1beta1" vertex_version = "v1beta1"
base_requested_route = "{}/projects/{}/locations/{}".format( base_requested_route = "{}/projects/{}/locations/{}".format(
vertex_version, default_vertex_project, default_vertex_location vertex_version, vertex_project, vertex_location
) )
updated_requested_route = "/" + base_requested_route + requested_route updated_requested_route = "/" + base_requested_route + requested_route

View file

@ -463,13 +463,8 @@ async def vertex_proxy_route(
location=vertex_location, location=vertex_location,
) )
if vertex_credentials is None:
raise Exception(
f"No matching vertex credentials found, for project_id: {vertex_project}, location: {vertex_location}. No default_vertex_config set either."
)
# Use headers from the incoming request if no vertex credentials are found # Use headers from the incoming request if no vertex credentials are found
if vertex_credentials.vertex_project is None: if vertex_credentials is None or vertex_credentials.vertex_project is None:
headers = dict(request.headers) or {} headers = dict(request.headers) or {}
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"default_vertex_config not set, incoming request headers %s", headers "default_vertex_config not set, incoming request headers %s", headers
@ -518,8 +513,8 @@ async def vertex_proxy_route(
updated_url = construct_target_url( updated_url = construct_target_url(
base_url=base_target_url, base_url=base_target_url,
requested_route=encoded_endpoint, requested_route=encoded_endpoint,
default_vertex_location=vertex_location, vertex_location=vertex_location,
default_vertex_project=vertex_project, vertex_project=vertex_project,
) )
# base_url = httpx.URL(base_target_url) # base_url = httpx.URL(base_target_url)
# updated_url = base_url.copy_with(path=encoded_endpoint) # updated_url = base_url.copy_with(path=encoded_endpoint)

View file

@ -150,19 +150,17 @@ class PassthroughEndpointRouter:
""" """
Get the vertex credentials for the given project-id, location Get the vertex credentials for the given project-id, location
""" """
default_vertex_config: Optional[VertexPassThroughCredentials] = None
deployment_key = self._get_deployment_key( deployment_key = self._get_deployment_key(
project_id=project_id, project_id=project_id,
location=location, location=location,
) )
if deployment_key is None: if deployment_key is None:
return default_vertex_config return self.default_vertex_config
if deployment_key in self.deployment_key_to_vertex_credentials: if deployment_key in self.deployment_key_to_vertex_credentials:
return self.deployment_key_to_vertex_credentials[deployment_key] return self.deployment_key_to_vertex_credentials[deployment_key]
else: else:
return default_vertex_config return self.default_vertex_config
def _get_credential_name_for_provider( def _get_credential_name_for_provider(
self, self,

View file

@ -1,6 +1,7 @@
import json import json
import os import os
import sys import sys
import traceback
from unittest import mock from unittest import mock
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@ -262,3 +263,71 @@ class TestVertexAIPassThroughHandler:
target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent", target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent",
custom_headers={"Authorization": f"Bearer {test_token}"}, custom_headers={"Authorization": f"Bearer {test_token}"},
) )
@pytest.mark.asyncio
async def test_vertex_passthrough_with_default_credentials(self, monkeypatch):
"""
Test that when no passthrough credentials are set, default credentials are used in the request
"""
from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
PassthroughEndpointRouter,
)
# Setup default credentials
default_project = "default-project"
default_location = "us-central1"
default_credentials = "default-creds"
pass_through_router = PassthroughEndpointRouter()
pass_through_router.default_vertex_config = VertexPassThroughCredentials(
vertex_project=default_project,
vertex_location=default_location,
vertex_credentials=default_credentials,
)
monkeypatch.setattr(
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
pass_through_router,
)
# Use different project/location in request than the default
request_project = "non-existing-project"
request_location = "bad-location"
endpoint = f"/v1/projects/{request_project}/locations/{request_location}/publishers/google/models/gemini-1.5-flash:generateContent"
mock_request = Request(
scope={
"type": "http",
"method": "POST",
"path": endpoint,
"headers": {},
}
)
mock_response = Response()
with mock.patch(
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
) as mock_ensure_token, mock.patch(
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
) as mock_get_token, mock.patch(
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
) as mock_create_route:
mock_ensure_token.return_value = ("test-auth-header", default_project)
mock_get_token.return_value = (default_credentials, "")
try:
await vertex_proxy_route(
endpoint=endpoint,
request=mock_request,
fastapi_response=mock_response,
)
except Exception as e:
traceback.print_exc()
print(f"Error: {e}")
# Verify default credentials were used
mock_create_route.assert_called_once_with(
endpoint=endpoint,
target=f"https://{default_location}-aiplatform.googleapis.com/v1/projects/{default_project}/locations/{default_location}/publishers/google/models/gemini-1.5-flash:generateContent",
custom_headers={"Authorization": f"Bearer {default_credentials}"},
)