mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(vertex_ai/common_utils.py): fix handling constructed url with default vertex config
This commit is contained in:
parent
b44b3bd36b
commit
06e69a414e
4 changed files with 98 additions and 17 deletions
|
@ -303,11 +303,26 @@ def get_vertex_location_from_url(url: str) -> Optional[str]:
|
|||
return match.group(1) if match else None
|
||||
|
||||
|
||||
def replace_project_and_location_in_route(
|
||||
requested_route: str, vertex_project: str, vertex_location: str
|
||||
) -> str:
|
||||
"""
|
||||
Replace project and location values in the route with the provided values
|
||||
"""
|
||||
# Replace project and location values while keeping route structure
|
||||
modified_route = re.sub(
|
||||
r"/projects/[^/]+/locations/[^/]+/",
|
||||
f"/projects/{vertex_project}/locations/{vertex_location}/",
|
||||
requested_route,
|
||||
)
|
||||
return modified_route
|
||||
|
||||
|
||||
def construct_target_url(
|
||||
base_url: str,
|
||||
requested_route: str,
|
||||
default_vertex_location: Optional[str],
|
||||
default_vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
vertex_project: Optional[str],
|
||||
) -> httpx.URL:
|
||||
"""
|
||||
Allow user to specify their own project id / location.
|
||||
|
@ -321,8 +336,12 @@ def construct_target_url(
|
|||
"""
|
||||
new_base_url = httpx.URL(base_url)
|
||||
if "locations" in requested_route: # contains the target project id + location
|
||||
updated_url = new_base_url.copy_with(path=requested_route)
|
||||
return updated_url
|
||||
if vertex_project and vertex_location:
|
||||
requested_route = replace_project_and_location_in_route(
|
||||
requested_route, vertex_project, vertex_location
|
||||
)
|
||||
return new_base_url.copy_with(path=requested_route)
|
||||
|
||||
"""
|
||||
- Add endpoint version (e.g. v1beta for cachedContent, v1 for rest)
|
||||
- Add default project id
|
||||
|
@ -333,7 +352,7 @@ def construct_target_url(
|
|||
vertex_version = "v1beta1"
|
||||
|
||||
base_requested_route = "{}/projects/{}/locations/{}".format(
|
||||
vertex_version, default_vertex_project, default_vertex_location
|
||||
vertex_version, vertex_project, vertex_location
|
||||
)
|
||||
|
||||
updated_requested_route = "/" + base_requested_route + requested_route
|
||||
|
|
|
@ -463,13 +463,8 @@ async def vertex_proxy_route(
|
|||
location=vertex_location,
|
||||
)
|
||||
|
||||
if vertex_credentials is None:
|
||||
raise Exception(
|
||||
f"No matching vertex credentials found, for project_id: {vertex_project}, location: {vertex_location}. No default_vertex_config set either."
|
||||
)
|
||||
|
||||
# Use headers from the incoming request if no vertex credentials are found
|
||||
if vertex_credentials.vertex_project is None:
|
||||
if vertex_credentials is None or vertex_credentials.vertex_project is None:
|
||||
headers = dict(request.headers) or {}
|
||||
verbose_proxy_logger.debug(
|
||||
"default_vertex_config not set, incoming request headers %s", headers
|
||||
|
@ -518,8 +513,8 @@ async def vertex_proxy_route(
|
|||
updated_url = construct_target_url(
|
||||
base_url=base_target_url,
|
||||
requested_route=encoded_endpoint,
|
||||
default_vertex_location=vertex_location,
|
||||
default_vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
vertex_project=vertex_project,
|
||||
)
|
||||
# base_url = httpx.URL(base_target_url)
|
||||
# updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
|
|
@ -150,19 +150,17 @@ class PassthroughEndpointRouter:
|
|||
"""
|
||||
Get the vertex credentials for the given project-id, location
|
||||
"""
|
||||
default_vertex_config: Optional[VertexPassThroughCredentials] = None
|
||||
|
||||
deployment_key = self._get_deployment_key(
|
||||
project_id=project_id,
|
||||
location=location,
|
||||
)
|
||||
|
||||
if deployment_key is None:
|
||||
return default_vertex_config
|
||||
return self.default_vertex_config
|
||||
if deployment_key in self.deployment_key_to_vertex_credentials:
|
||||
return self.deployment_key_to_vertex_credentials[deployment_key]
|
||||
else:
|
||||
return default_vertex_config
|
||||
return self.default_vertex_config
|
||||
|
||||
def _get_credential_name_for_provider(
|
||||
self,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
@ -262,3 +263,71 @@ class TestVertexAIPassThroughHandler:
|
|||
target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent",
|
||||
custom_headers={"Authorization": f"Bearer {test_token}"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vertex_passthrough_with_default_credentials(self, monkeypatch):
|
||||
"""
|
||||
Test that when no passthrough credentials are set, default credentials are used in the request
|
||||
"""
|
||||
from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
|
||||
PassthroughEndpointRouter,
|
||||
)
|
||||
|
||||
# Setup default credentials
|
||||
default_project = "default-project"
|
||||
default_location = "us-central1"
|
||||
default_credentials = "default-creds"
|
||||
|
||||
pass_through_router = PassthroughEndpointRouter()
|
||||
pass_through_router.default_vertex_config = VertexPassThroughCredentials(
|
||||
vertex_project=default_project,
|
||||
vertex_location=default_location,
|
||||
vertex_credentials=default_credentials,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
|
||||
pass_through_router,
|
||||
)
|
||||
|
||||
# Use different project/location in request than the default
|
||||
request_project = "non-existing-project"
|
||||
request_location = "bad-location"
|
||||
endpoint = f"/v1/projects/{request_project}/locations/{request_location}/publishers/google/models/gemini-1.5-flash:generateContent"
|
||||
|
||||
mock_request = Request(
|
||||
scope={
|
||||
"type": "http",
|
||||
"method": "POST",
|
||||
"path": endpoint,
|
||||
"headers": {},
|
||||
}
|
||||
)
|
||||
mock_response = Response()
|
||||
|
||||
with mock.patch(
|
||||
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
|
||||
) as mock_ensure_token, mock.patch(
|
||||
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
|
||||
) as mock_get_token, mock.patch(
|
||||
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
|
||||
) as mock_create_route:
|
||||
mock_ensure_token.return_value = ("test-auth-header", default_project)
|
||||
mock_get_token.return_value = (default_credentials, "")
|
||||
|
||||
try:
|
||||
await vertex_proxy_route(
|
||||
endpoint=endpoint,
|
||||
request=mock_request,
|
||||
fastapi_response=mock_response,
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Error: {e}")
|
||||
|
||||
# Verify default credentials were used
|
||||
mock_create_route.assert_called_once_with(
|
||||
endpoint=endpoint,
|
||||
target=f"https://{default_location}-aiplatform.googleapis.com/v1/projects/{default_project}/locations/{default_location}/publishers/google/models/gemini-1.5-flash:generateContent",
|
||||
custom_headers={"Authorization": f"Bearer {default_credentials}"},
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue